@@ -35,6 +35,7 @@ class GenerationRequest(BaseModel):
35
35
top_p : float = Field (default = DEFAULT_TOP_P , ge = 0.0 , le = 1.0 )
36
36
top_k : int = Field (default = 80 , ge = 1 , le = 1000 ) # Added top_k parameter
37
37
repetition_penalty : float = Field (default = 1.15 , ge = 1.0 , le = 2.0 ) # Added repetition_penalty parameter
38
+ max_time : Optional [float ] = Field (default = None , ge = 0.0 , description = "Maximum time in seconds for generation" )
38
39
system_prompt : Optional [str ] = Field (default = DEFAULT_SYSTEM_INSTRUCTIONS )
39
40
stream : bool = Field (default = False )
40
41
@@ -47,6 +48,7 @@ class BatchGenerationRequest(BaseModel):
47
48
top_p : float = Field (default = DEFAULT_TOP_P , ge = 0.0 , le = 1.0 )
48
49
top_k : int = Field (default = 80 , ge = 1 , le = 1000 ) # Added top_k parameter
49
50
repetition_penalty : float = Field (default = 1.15 , ge = 1.0 , le = 2.0 ) # Added repetition_penalty parameter
51
+ max_time : Optional [float ] = Field (default = None , ge = 0.0 , description = "Maximum time in seconds for generation" )
50
52
system_prompt : Optional [str ] = Field (default = DEFAULT_SYSTEM_INSTRUCTIONS )
51
53
52
54
@@ -64,6 +66,7 @@ class ChatRequest(BaseModel):
64
66
top_p : float = Field (default = DEFAULT_TOP_P , ge = 0.0 , le = 1.0 )
65
67
top_k : int = Field (default = 80 , ge = 1 , le = 1000 ) # Added top_k parameter
66
68
repetition_penalty : float = Field (default = 1.15 , ge = 1.0 , le = 2.0 ) # Added repetition_penalty parameter
69
+ max_time : Optional [float ] = Field (default = None , ge = 0.0 , description = "Maximum time in seconds for generation" )
67
70
stream : bool = Field (default = False )
68
71
69
72
@@ -129,7 +132,7 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
129
132
# Return a streaming response
130
133
return StreamingResponse (
131
134
generate_stream (request .prompt , request .max_tokens , request .temperature ,
132
- request .top_p , request .system_prompt ),
135
+ request .top_p , request .system_prompt , request . max_time ),
133
136
media_type = "text/event-stream"
134
137
)
135
138
@@ -144,7 +147,8 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
144
147
"top_p" : request .top_p if request .top_p is not None else 0.92 , # Optimized default
145
148
"top_k" : request .top_k if request .top_k is not None else 80 , # Optimized default
146
149
"repetition_penalty" : request .repetition_penalty if request .repetition_penalty is not None else 1.15 , # Optimized default
147
- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
150
+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
151
+ "max_time" : request .max_time # Pass max_time parameter
148
152
}
149
153
150
154
# Merge model-specific params with request params
@@ -212,7 +216,7 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
212
216
# If streaming is requested, return a streaming response
213
217
if request .stream :
214
218
return StreamingResponse (
215
- stream_chat (formatted_prompt , request .max_tokens , request .temperature , request .top_p ),
219
+ stream_chat (formatted_prompt , request .max_tokens , request .temperature , request .top_p , request . max_time ),
216
220
media_type = "text/event-stream"
217
221
)
218
222
@@ -227,7 +231,8 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
227
231
"top_p" : request .top_p if request .top_p is not None else 0.92 , # Optimized default
228
232
"top_k" : request .top_k if request .top_k is not None else 80 , # Optimized default
229
233
"repetition_penalty" : request .repetition_penalty if request .repetition_penalty is not None else 1.15 , # Optimized default
230
- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
234
+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
235
+ "max_time" : request .max_time # Pass max_time parameter
231
236
}
232
237
233
238
# Merge model-specific params with request params
@@ -292,7 +297,8 @@ async def generate_stream(
292
297
max_tokens : int ,
293
298
temperature : float ,
294
299
top_p : float ,
295
- system_prompt : Optional [str ]
300
+ system_prompt : Optional [str ],
301
+ max_time : Optional [float ] = None
296
302
) -> AsyncGenerator [str , None ]:
297
303
"""
298
304
Generate text in a streaming fashion and return as server-sent events
@@ -309,7 +315,8 @@ async def generate_stream(
309
315
"top_p" : top_p ,
310
316
"top_k" : 80 , # Optimized top_k for high-quality streaming
311
317
"repetition_penalty" : 1.15 , # Optimized repetition_penalty for high-quality streaming
312
- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
318
+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
319
+ "max_time" : max_time # Pass max_time parameter
313
320
}
314
321
315
322
# Merge model-specific params with request params
@@ -361,7 +368,8 @@ async def stream_chat(
361
368
formatted_prompt : str ,
362
369
max_tokens : int ,
363
370
temperature : float ,
364
- top_p : float
371
+ top_p : float ,
372
+ max_time : Optional [float ] = None
365
373
) -> AsyncGenerator [str , None ]:
366
374
"""
367
375
Stream chat completion responses as server-sent events
@@ -378,7 +386,8 @@ async def stream_chat(
378
386
"top_p" : top_p ,
379
387
"top_k" : 80 , # Optimized top_k for high-quality streaming
380
388
"repetition_penalty" : 1.15 , # Optimized repetition_penalty for high-quality streaming
381
- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
389
+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
390
+ "max_time" : max_time # Pass max_time parameter
382
391
}
383
392
384
393
# Merge model-specific params with request params
@@ -438,7 +447,8 @@ async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResp
438
447
"top_p" : request .top_p if request .top_p is not None else 0.92 , # Optimized default
439
448
"top_k" : request .top_k if request .top_k is not None else 80 , # Optimized default
440
449
"repetition_penalty" : request .repetition_penalty if request .repetition_penalty is not None else 1.15 , # Optimized default
441
- "do_sample" : model_params .get ("do_sample" , True ) # Pass do_sample from model params
450
+ "do_sample" : model_params .get ("do_sample" , True ), # Pass do_sample from model params
451
+ "max_time" : request .max_time # Pass max_time parameter
442
452
}
443
453
444
454
# Merge model-specific params with request params
0 commit comments