@@ -137,28 +137,29 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
     # Get model-specific generation parameters
     model_params = get_model_generation_params(model_manager.current_model)
 
-    # Update with request parameters
+    # Update with request parameters and optimized defaults for high-quality responses
     generation_params = {
         "max_new_tokens": request.max_tokens,
         "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k,
-        "repetition_penalty": request.repetition_penalty,
+        "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
+        "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
+        "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
         "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
     }
 
     # Merge model-specific params with request params
+    # This ensures we get the best of both worlds - model-specific optimizations
+    # and our high-quality parameters
     generation_params.update(model_params)
 
-    # Generate text - properly await the async call
+    # Generate text with optimized parameters - properly await the async call
     generated_text = await model_manager.generate_text(
         prompt=request.prompt,
         system_prompt=request.system_prompt,
         **generation_params
     )
 
     # Additional cleanup for any special tokens that might have slipped through
-    import re
     special_token_pattern = r'<\|[a-zA-Z0-9_]+\|>'
     cleaned_text = re.sub(special_token_pattern, '', generated_text)
 
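The `is not None` fallbacks above only fire when a field was omitted from the request; note that the later `generation_params.update(model_params)` lets model-specific values override both the request values and these defaults. A minimal sketch of the resulting precedence, assuming plain dicts (the helper name `resolve_params` is illustrative, not part of this commit):

```python
# Illustrative sketch of the precedence in the hunk above: request value if
# provided, else the optimized default, with model-specific params winning last.
# resolve_params is a hypothetical helper, not code from this commit.
def resolve_params(request_params: dict, model_params: dict) -> dict:
    defaults = {"top_p": 0.92, "top_k": 80, "repetition_penalty": 1.15}
    resolved = {
        key: request_params[key] if request_params.get(key) is not None else default
        for key, default in defaults.items()
    }
    resolved.update(model_params)  # model-specific params take final precedence
    return resolved

# Example: the request's top_k survives, the omitted fields fall back to the
# defaults, and the model-specific top_p overrides everything.
print(resolve_params(
    {"top_p": None, "top_k": 50, "repetition_penalty": None},
    {"top_p": 0.9},
))
# -> {'top_p': 0.9, 'top_k': 50, 'repetition_penalty': 1.15}
```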
@@ -172,6 +173,22 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
             cleaned_text = cleaned_text[:marker_pos]
             break
 
+    # Check for repetition patterns that indicate the model is stuck
+    if len(cleaned_text) > 200:
+        # Look for repeating patterns of 20+ characters that repeat 3+ times
+        for pattern_len in range(20, 40):
+            if pattern_len < len(cleaned_text) // 3:
+                for i in range(len(cleaned_text) - pattern_len * 3):
+                    pattern = cleaned_text[i:i + pattern_len]
+                    if pattern and not pattern.isspace():
+                        if cleaned_text[i:].count(pattern) >= 3:
+                            # Found a repeating pattern, truncate at the second occurrence
+                            second_pos = cleaned_text.find(pattern, i + pattern_len)
+                            if second_pos > 0:
+                                logger.info("Detected repetition pattern in text generation, truncating response")
+                                cleaned_text = cleaned_text[:second_pos + pattern_len]
+                                break
+
     return GenerationResponse(
         text=cleaned_text,
         model=model_manager.current_model
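The repetition check added here scans every offset for a 20-39 character window that recurs at least three times in the remainder of the text, then cuts the output just after the second occurrence. Pulled out as a self-contained function (the name `truncate_repetition` and its parameters are illustrative, not part of this commit), the logic is roughly:

```python
# Standalone sketch of the repetition truncation above. Returning instead of
# using `break` exits all loop levels once a repeat is found; the inline
# version's single `break` only leaves the innermost loop, so the outer
# pattern_len loop keeps rescanning.
def truncate_repetition(text: str, min_len: int = 20, max_len: int = 40,
                        repeats: int = 3) -> str:
    if len(text) <= 200:
        return text
    for pattern_len in range(min_len, max_len):
        if pattern_len >= len(text) // repeats:
            break  # window too wide to repeat the required number of times
        for i in range(len(text) - pattern_len * repeats):
            pattern = text[i:i + pattern_len]
            if not pattern or pattern.isspace():
                continue
            if text[i:].count(pattern) >= repeats:
                # Keep everything up to and including the second occurrence
                second_pos = text.find(pattern, i + pattern_len)
                if second_pos > 0:
                    return text[:second_pos + pattern_len]
    return text
```

Note that the scan is quadratic in the text length for each window size, so it can get slow on very long completions; comparing only fixed-stride windows would be a cheaper variant.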
@@ -203,27 +220,28 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
     # Get model-specific generation parameters
     model_params = get_model_generation_params(model_manager.current_model)
 
-    # Prepare generation parameters
+    # Prepare generation parameters with optimized defaults for high-quality responses
     generation_params = {
         "max_new_tokens": request.max_tokens,
         "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k,
-        "repetition_penalty": request.repetition_penalty,
+        "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
+        "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
+        "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
         "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
     }
 
     # Merge model-specific params with request params
+    # This ensures we get the best of both worlds - model-specific optimizations
+    # and our high-quality parameters
     generation_params.update(model_params)
 
-    # Generate completion
+    # Generate completion with optimized parameters
     generated_text = await model_manager.generate_text(
         prompt=formatted_prompt,
         **generation_params
     )
 
     # Additional cleanup for any special tokens that might have slipped through
-    import re
     special_token_pattern = r'<\|[a-zA-Z0-9_]+\|>'
     cleaned_text = re.sub(special_token_pattern, '', generated_text)
 
@@ -237,6 +255,22 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
             cleaned_text = cleaned_text[:marker_pos]
             break
 
+    # Check for repetition patterns that indicate the model is stuck
+    if len(cleaned_text) > 200:
+        # Look for repeating patterns of 20+ characters that repeat 3+ times
+        for pattern_len in range(20, 40):
+            if pattern_len < len(cleaned_text) // 3:
+                for i in range(len(cleaned_text) - pattern_len * 3):
+                    pattern = cleaned_text[i:i + pattern_len]
+                    if pattern and not pattern.isspace():
+                        if cleaned_text[i:].count(pattern) >= 3:
+                            # Found a repeating pattern, truncate at the second occurrence
+                            second_pos = cleaned_text.find(pattern, i + pattern_len)
+                            if second_pos > 0:
+                                logger.info("Detected repetition pattern in chat completion, truncating response")
+                                cleaned_text = cleaned_text[:second_pos + pattern_len]
+                                break
+
     # Format response with cleaned text
     return ChatResponse(
         choices=[{
@@ -388,7 +422,7 @@ async def stream_chat(
 @router.post("/generate/batch", response_model=BatchGenerationResponse)
 async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResponse:
     """
-    Generate text for multiple prompts in a single request
+    Generate high-quality text for multiple prompts in a single request
     """
     if not model_manager.current_model:
         raise HTTPException(status_code=400, detail="No model is currently loaded")
@@ -397,29 +431,31 @@ async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResponse:
     # Get model-specific generation parameters
     model_params = get_model_generation_params(model_manager.current_model)
 
-    # Update with request parameters
+    # Update with request parameters and optimized defaults for high-quality responses
     generation_params = {
         "max_new_tokens": request.max_tokens,
         "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k,
-        "repetition_penalty": request.repetition_penalty,
+        "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
+        "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
+        "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
         "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
     }
 
     # Merge model-specific params with request params
+    # This ensures we get the best of both worlds - model-specific optimizations
+    # and our high-quality parameters
     generation_params.update(model_params)
 
     responses = []
     for prompt in request.prompts:
+        # Generate text with optimized parameters
         generated_text = await model_manager.generate_text(
             prompt=prompt,
             system_prompt=request.system_prompt,
             **generation_params
         )
 
         # Additional cleanup for any special tokens that might have slipped through
-        import re
         special_token_pattern = r'<\|[a-zA-Z0-9_]+\|>'
         cleaned_text = re.sub(special_token_pattern, '', generated_text)
 
@@ -433,6 +469,22 @@ async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResponse:
                 cleaned_text = cleaned_text[:marker_pos]
                 break
 
+        # Check for repetition patterns that indicate the model is stuck
+        if len(cleaned_text) > 200:
+            # Look for repeating patterns of 20+ characters that repeat 3+ times
+            for pattern_len in range(20, 40):
+                if pattern_len < len(cleaned_text) // 3:
+                    for i in range(len(cleaned_text) - pattern_len * 3):
+                        pattern = cleaned_text[i:i + pattern_len]
+                        if pattern and not pattern.isspace():
+                            if cleaned_text[i:].count(pattern) >= 3:
+                                # Found a repeating pattern, truncate at the second occurrence
+                                second_pos = cleaned_text.find(pattern, i + pattern_len)
+                                if second_pos > 0:
+                                    logger.info("Detected repetition pattern in batch generation, truncating response")
+                                    cleaned_text = cleaned_text[:second_pos + pattern_len]
+                                    break
+
         responses.append(cleaned_text)
 
     return BatchGenerationResponse(responses=responses)
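After this change, the special-token strip and repetition check are duplicated verbatim across `generate_text`, `chat_completion`, and `batch_generate`. A shared helper would keep the three endpoints in sync; a minimal sketch, reusing the hypothetical `truncate_repetition` from the earlier sketch (not part of this commit):

```python
import re

# Matches leftover chat-template tokens such as <|assistant|> or <|im_end|>
SPECIAL_TOKEN_PATTERN = re.compile(r'<\|[a-zA-Z0-9_]+\|>')

def clean_generated_text(text: str) -> str:
    """Strip special tokens, then trim stuck-in-a-loop repetition."""
    text = SPECIAL_TOKEN_PATTERN.sub('', text)
    return truncate_repetition(text)  # sketch defined above
```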