Commit 5d5cfea

Added Max Time Params Handling
1 parent 1f8200a commit 5d5cfea

File tree

5 files changed: +43 additions, -17 deletions

CHANGELOG.md
locallab/__init__.py
locallab/model_manager.py
locallab/routes/generate.py
setup.py


CHANGELOG.md

Lines changed: 12 additions & 0 deletions

@@ -2,6 +2,18 @@
 
 All notable changes to LocalLab will be documented in this file.
 
+## [0.7.1] - 2025-05-18
+
+### Fixed
+
+- Fixed critical error: "ModelManager.generate() got an unexpected keyword argument 'max_time'"
+- Added proper handling of the `max_time` parameter in all generation endpoints
+- Updated `ModelManager.generate()` method to accept the `max_time` parameter
+- Added `max_time` parameter to all request models (GenerationRequest, BatchGenerationRequest, ChatRequest)
+- Ensured consistent parameter passing between client and server
+- Set default max_time to 180 seconds (3 minutes) when not specified
+- Improved error handling for generation timeouts
+
 ## Client Package [1.1.0] - 2025-05-17
 
 ### Added
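
In practice, the client-side change is simply that generation requests may now carry a max_time field. A minimal sketch of such a call, assuming the generation route from locallab/routes/generate.py is served at /generate on http://localhost:8000 (the exact path, port, and response shape are not shown in this commit), using the third-party requests library:

    import requests  # third-party HTTP client, assumed installed

    payload = {
        "prompt": "Summarize the plot of Hamlet in two sentences.",
        "max_tokens": 200,
        "temperature": 0.7,
        "max_time": 60.0,  # cap generation at 60 seconds; omit to get the server's 180s default
    }
    resp = requests.post("http://localhost:8000/generate", json=payload, timeout=90)
    resp.raise_for_status()
    print(resp.json())

If max_time is omitted or null, the server falls back to the 180-second default described in the entries above.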

locallab/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
 # This ensures Hugging Face's progress bars are displayed correctly
 from .utils.early_config import configure_hf_logging
 
-__version__ = "0.7.0"  # Improved stream generation and non-streaming generation quality
+__version__ = "0.7.1"  # Fixed max_time parameter handling in generation endpoints
 
 # Only import what's necessary initially, lazy-load the rest
 from .logger import get_logger

locallab/model_manager.py

Lines changed: 10 additions & 6 deletions

@@ -399,7 +399,8 @@ async def generate(
         top_k: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         system_instructions: Optional[str] = None,
-        do_sample: bool = True
+        do_sample: bool = True,
+        max_time: Optional[float] = None
     ) -> str:
         """Generate text from the model"""
         # Check model timeout
@@ -527,10 +528,13 @@ async def generate(
         }
 
         # Set a reasonable max time for generation to prevent hanging
-        # Use the DEFAULT_MAX_TIME from config (increased to 180 seconds)
-        if "max_time" not in generate_params and not stream:
-            from .config import DEFAULT_MAX_TIME
-            generate_params["max_time"] = DEFAULT_MAX_TIME  # Use the default max time from config
+        # Use the provided max_time or a default value of 180 seconds
+        if not stream:
+            if max_time is not None:
+                generate_params["max_time"] = max_time
+            elif "max_time" not in generate_params:
+                # Default to 180 seconds (3 minutes) if not specified
+                generate_params["max_time"] = 180.0  # Default max time in seconds
 
         # Define comprehensive stop sequences for proper termination
         stop_sequences = [
@@ -918,7 +922,7 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
 
         # Update with provided kwargs
         for key, value in kwargs.items():
-            if key in ["max_length", "temperature", "top_p", "top_k", "repetition_penalty"]:
+            if key in ["max_length", "temperature", "top_p", "top_k", "repetition_penalty", "max_time"]:
                 gen_params[key] = value
             elif key == "max_new_tokens":
                 # Handle the max_new_tokens parameter by mapping to max_length
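
The substance of this fix is a precedence rule for the time budget: an explicit max_time from the caller wins, otherwise non-streaming calls fall back to 180 seconds. A standalone sketch of just that rule, assuming generate_params is ultimately forwarded to the underlying Hugging Face model.generate() call, where max_time acts as a wall-clock stopping criterion (the helper name and constant below are illustrative, not part of this commit):

    from typing import Any, Dict, Optional

    DEFAULT_MAX_TIME_SECONDS = 180.0  # assumed constant, mirroring the 3-minute fallback above

    def resolve_max_time(max_time: Optional[float], stream: bool,
                         generate_params: Dict[str, Any]) -> Dict[str, Any]:
        """Apply the same precedence as the patched ModelManager.generate():
        explicit request value > pre-existing param > 180s default (non-streaming only)."""
        if not stream:
            if max_time is not None:
                generate_params["max_time"] = max_time
            elif "max_time" not in generate_params:
                generate_params["max_time"] = DEFAULT_MAX_TIME_SECONDS
        return generate_params

    print(resolve_max_time(30.0, stream=False, generate_params={}))  # {'max_time': 30.0}
    print(resolve_max_time(None, stream=False, generate_params={}))  # {'max_time': 180.0}
    print(resolve_max_time(None, stream=True, generate_params={}))   # {} -- streaming path untouched

Note that the new code also drops the import of DEFAULT_MAX_TIME from locallab.config in this path and hard-codes 180.0 instead.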

locallab/routes/generate.py

Lines changed: 19 additions & 9 deletions

@@ -35,6 +35,7 @@ class GenerationRequest(BaseModel):
     top_p: float = Field(default=DEFAULT_TOP_P, ge=0.0, le=1.0)
     top_k: int = Field(default=80, ge=1, le=1000)  # Added top_k parameter
     repetition_penalty: float = Field(default=1.15, ge=1.0, le=2.0)  # Added repetition_penalty parameter
+    max_time: Optional[float] = Field(default=None, ge=0.0, description="Maximum time in seconds for generation")
     system_prompt: Optional[str] = Field(default=DEFAULT_SYSTEM_INSTRUCTIONS)
     stream: bool = Field(default=False)
 
@@ -47,6 +48,7 @@ class BatchGenerationRequest(BaseModel):
     top_p: float = Field(default=DEFAULT_TOP_P, ge=0.0, le=1.0)
     top_k: int = Field(default=80, ge=1, le=1000)  # Added top_k parameter
     repetition_penalty: float = Field(default=1.15, ge=1.0, le=2.0)  # Added repetition_penalty parameter
+    max_time: Optional[float] = Field(default=None, ge=0.0, description="Maximum time in seconds for generation")
     system_prompt: Optional[str] = Field(default=DEFAULT_SYSTEM_INSTRUCTIONS)
 
 
@@ -64,6 +66,7 @@ class ChatRequest(BaseModel):
     top_p: float = Field(default=DEFAULT_TOP_P, ge=0.0, le=1.0)
     top_k: int = Field(default=80, ge=1, le=1000)  # Added top_k parameter
     repetition_penalty: float = Field(default=1.15, ge=1.0, le=2.0)  # Added repetition_penalty parameter
+    max_time: Optional[float] = Field(default=None, ge=0.0, description="Maximum time in seconds for generation")
     stream: bool = Field(default=False)
 
 
@@ -129,7 +132,7 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
         # Return a streaming response
         return StreamingResponse(
             generate_stream(request.prompt, request.max_tokens, request.temperature,
-                            request.top_p, request.system_prompt),
+                            request.top_p, request.system_prompt, request.max_time),
             media_type="text/event-stream"
         )
 
@@ -144,7 +147,8 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
         "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
         "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
         "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
-        "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
+        "do_sample": model_params.get("do_sample", True),  # Pass do_sample from model params
+        "max_time": request.max_time  # Pass max_time parameter
     }
 
     # Merge model-specific params with request params
@@ -212,7 +216,7 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
     # If streaming is requested, return a streaming response
     if request.stream:
         return StreamingResponse(
-            stream_chat(formatted_prompt, request.max_tokens, request.temperature, request.top_p),
+            stream_chat(formatted_prompt, request.max_tokens, request.temperature, request.top_p, request.max_time),
            media_type="text/event-stream"
         )
 
@@ -227,7 +231,8 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
         "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
         "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
         "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
-        "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
+        "do_sample": model_params.get("do_sample", True),  # Pass do_sample from model params
+        "max_time": request.max_time  # Pass max_time parameter
     }
 
     # Merge model-specific params with request params
@@ -292,7 +297,8 @@ async def generate_stream(
     max_tokens: int,
     temperature: float,
     top_p: float,
-    system_prompt: Optional[str]
+    system_prompt: Optional[str],
+    max_time: Optional[float] = None
 ) -> AsyncGenerator[str, None]:
     """
     Generate text in a streaming fashion and return as server-sent events
@@ -309,7 +315,8 @@
         "top_p": top_p,
         "top_k": 80,  # Optimized top_k for high-quality streaming
         "repetition_penalty": 1.15,  # Optimized repetition_penalty for high-quality streaming
-        "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
+        "do_sample": model_params.get("do_sample", True),  # Pass do_sample from model params
+        "max_time": max_time  # Pass max_time parameter
     }
 
     # Merge model-specific params with request params
@@ -361,7 +368,8 @@ async def stream_chat(
     formatted_prompt: str,
     max_tokens: int,
     temperature: float,
-    top_p: float
+    top_p: float,
+    max_time: Optional[float] = None
 ) -> AsyncGenerator[str, None]:
     """
     Stream chat completion responses as server-sent events
@@ -378,7 +386,8 @@
         "top_p": top_p,
         "top_k": 80,  # Optimized top_k for high-quality streaming
         "repetition_penalty": 1.15,  # Optimized repetition_penalty for high-quality streaming
-        "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
+        "do_sample": model_params.get("do_sample", True),  # Pass do_sample from model params
+        "max_time": max_time  # Pass max_time parameter
     }
 
     # Merge model-specific params with request params
@@ -438,7 +447,8 @@ async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResp
         "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
         "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
         "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
-        "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
+        "do_sample": model_params.get("do_sample", True),  # Pass do_sample from model params
+        "max_time": request.max_time  # Pass max_time parameter
     }
 
     # Merge model-specific params with request params
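
For reference, the request-model change in this file amounts to one optional, non-negative float field on each Pydantic model. A minimal, self-contained sketch of just that field and its validation behaviour (only the field names needed for the example are kept; everything else on the real GenerationRequest is omitted):

    from typing import Optional
    from pydantic import BaseModel, Field, ValidationError

    class GenerationRequest(BaseModel):
        prompt: str
        max_time: Optional[float] = Field(
            default=None, ge=0.0,
            description="Maximum time in seconds for generation",
        )

    print(GenerationRequest(prompt="hi").max_time)               # None -> server applies its 180s default
    print(GenerationRequest(prompt="hi", max_time=45).max_time)  # 45.0

    try:
        GenerationRequest(prompt="hi", max_time=-1)
    except ValidationError as exc:
        print("rejected:", exc.errors()[0]["type"])  # negative values fail the ge=0.0 constraint

Because the field defaults to None rather than a number, existing clients that never send max_time keep their current behaviour and simply inherit the server-side 180-second cap.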

setup.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@
 
 setup(
     name="locallab",
-    version="0.7.0",
+    version="0.7.1",
     packages=find_packages(include=["locallab", "locallab.*"]),
     install_requires=install_requires,
     extras_require={
