@@ -73,7 +73,7 @@ def _ensure_connection(self):
     def _run_coroutine(self, coro, timeout: Optional[float] = None):
         """Run a coroutine in the event loop thread with timeout and error handling."""
         self._ensure_connection()
-
+
         try:
             future = asyncio.run_coroutine_threadsafe(coro, self._loop)
             return future.result(timeout=timeout)
@@ -123,10 +123,10 @@ def close(self):
             # Clean up
             self._loop = None
             self._thread = None
-
+
             # Shutdown executor
             self._executor.shutdown(wait=False)
-
+
         except Exception as e:
             logger.error(f"Error during client cleanup: {str(e)}")
         finally:
@@ -139,25 +139,40 @@ def generate(
         stream: bool = False,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Union[str, Generator[str, None, None]]:
         """
-        Generate text using the model.
+        Generate text using the model with improved quality settings.
 
         Args:
             prompt: The prompt to generate text from
             model_id: Optional model ID to use
             stream: Whether to stream the response
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             If stream=False, returns the generated text as a string.
             If stream=True, returns a generator that yields chunks of text.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         if stream:
-            return self.stream_generate(prompt, model_id, max_length, temperature, top_p)
+            return self.stream_generate(
+                prompt=prompt,
+                model_id=model_id,
+                max_length=max_length,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
+            )
 
         return self._run_coroutine(
             self._async_client.generate(
@@ -166,7 +181,10 @@ def generate(
                 stream=False,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k,
+                timeout=180.0  # Increased timeout for more complete responses (3 minutes)
             )
         )
 
@@ -177,22 +195,29 @@ def stream_generate(
         max_length: Optional[int] = None,
         temperature: float = 0.7,
         top_p: float = 0.9,
-        timeout: float = 60.0
+        timeout: float = 300.0,  # Increased timeout for more complete responses (5 minutes)
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Generator[str, None, None]:
         """
-        Stream text generation.
+        Stream text generation with improved quality and reliability.
 
         Args:
             prompt: The prompt to generate text from
             model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
             timeout: Request timeout in seconds
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             A generator that yields chunks of text as they are generated.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         # Create a queue to pass data between the async and sync worlds
         queue = asyncio.Queue()
         stop_event = threading.Event()
@@ -206,7 +231,10 @@ async def producer():
                     max_length=max_length,
                     temperature=temperature,
                     top_p=top_p,
-                    timeout=timeout
+                    timeout=timeout,
+                    retry_count=3,  # Increased retry count for better reliability
+                    repetition_penalty=repetition_penalty,  # Pass the repetition penalty parameter
+                    top_k=top_k  # Pass the top_k parameter
                 ):
                     await queue.put(chunk)
 
@@ -250,25 +278,41 @@ def chat(
         stream: bool = False,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
         """
-        Chat completion.
+        Chat completion with improved quality settings.
 
         Args:
             messages: List of message dictionaries with 'role' and 'content' keys
             model_id: Optional model ID to use
             stream: Whether to stream the response
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             If stream=False, returns the chat completion response.
             If stream=True, returns a generator that yields chunks of the response.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         if stream:
-            return self.stream_chat(messages, model_id, max_length, temperature, top_p)
+            return self.stream_chat(
+                messages=messages,
+                model_id=model_id,
+                max_length=max_length,
+                temperature=temperature,
+                top_p=top_p,
+                timeout=300.0,  # Increased timeout for more complete responses (5 minutes)
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
+            )
 
         return self._run_coroutine(
             self._async_client.chat(
@@ -277,7 +321,10 @@ def chat(
                 stream=False,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                timeout=180.0,  # Increased timeout for more complete responses (3 minutes)
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
             )
         )
 
@@ -287,21 +334,29 @@ def stream_chat(
         model_id: Optional[str] = None,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        timeout: float = 300.0,  # Increased timeout for more complete responses (5 minutes)
+        repetition_penalty: float = 1.15,  # Added repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Generator[Dict[str, Any], None, None]:
         """
-        Stream chat completion.
+        Stream chat completion with improved quality and reliability.
 
         Args:
             messages: List of message dictionaries with 'role' and 'content' keys
             model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            timeout: Request timeout in seconds
 
         Returns:
             A generator that yields chunks of the chat completion response.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         # Create a queue to pass data between the async and sync worlds
         queue = asyncio.Queue()
         stop_event = threading.Event()
@@ -314,7 +369,11 @@ async def producer():
                     model_id=model_id,
                     max_length=max_length,
                     temperature=temperature,
-                    top_p=top_p
+                    top_p=top_p,
+                    timeout=timeout,
+                    retry_count=3,  # Increased retry count for better reliability
+                    repetition_penalty=repetition_penalty,
+                    top_k=top_k
                 ):
                     await queue.put(chunk)
 
@@ -357,28 +416,39 @@ def batch_generate(
         model_id: Optional[str] = None,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80,  # Added top_k parameter for better quality
+        timeout: float = 300.0  # Added timeout parameter (5 minutes)
     ) -> Dict[str, List[str]]:
         """
-        Generate text for multiple prompts in parallel.
+        Generate text for multiple prompts in parallel with improved quality settings.
 
         Args:
             prompts: List of prompts to generate text from
             model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 4096 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             Dictionary with the generated responses.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         return self._run_coroutine(
             self._async_client.batch_generate(
                 prompts=prompts,
                 model_id=model_id,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k,
+                timeout=timeout  # Use the provided timeout parameter
             )
         )
 
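For reference, a minimal usage sketch of the updated synchronous API. The wrapper class name `LLMClient` and its no-argument constructor are assumptions for illustration; only the method bodies appear in this diff, and the example prompt values are placeholders.

```python
# Hypothetical usage of the updated wrapper (class name LLMClient is assumed).
client = LLMClient()
try:
    # Non-streaming generation: max_length now defaults to 4096 when left as None,
    # and repetition_penalty / top_k are forwarded to the async client.
    text = client.generate(
        "Summarize the event loop pattern in one paragraph.",
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.15,
        top_k=80,
    )
    print(text)

    # Streaming generation: chunks are yielded as they arrive from the background loop.
    for chunk in client.stream_generate("Write a haiku about queues."):
        print(chunk, end="", flush=True)
finally:
    client.close()
```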