
Commit 2f754f9 (1 parent: f48ad67)

Improved AI Model Response Quality
File tree

7 files changed: +358 -132 lines

client/python_client/locallab_client/client.py

Lines changed: 31 additions & 8 deletions
@@ -143,7 +143,7 @@ def __init__(self, base_url: str, timeout: float = 30.0, headers: Dict[str, str]
 
 class LocalLabClient:
     """Asynchronous client for the LocalLab API with improved error handling."""
-    
+
     # Class-level attribute for activity tracking
     _last_activity_times = {}
 
@@ -152,7 +152,7 @@ def __init__(self, config: Union[str, LocalLabConfig, Dict[str, Any]]):
             config = LocalLabConfig(base_url=config)
         elif isinstance(config, dict):
             config = LocalLabConfig(**config)
-        
+
         self.config = config
         self._session: Optional[aiohttp.ClientSession] = None
         self.ws: Optional[websockets.WebSocketClientProtocol] = None
@@ -167,7 +167,7 @@ async def connect(self):
         """Initialize HTTP session with improved error handling."""
         if self._closed:
             raise RuntimeError("Client is closed")
-        
+
         if not self._session:
             try:
                 headers = {
@@ -176,7 +176,7 @@ async def connect(self):
                 }
                 if self.config.api_key:
                     headers["Authorization"] = f"Bearer {self.config.api_key}"
-                
+
                 self._session = aiohttp.ClientSession(
                     headers=headers,
                     timeout=aiohttp.ClientTimeout(total=self.config.timeout),
@@ -291,20 +291,27 @@ async def stream_generate(
         max_length: Optional[int] = None,
         temperature: float = 0.7,
         top_p: float = 0.9,
-        timeout: float = 120.0,  # Increased timeout for low-resource CPUs
-        retry_count: int = 2  # Add retry count for reliability
+        timeout: float = 300.0,  # Increased timeout for more complete responses (5 minutes)
+        retry_count: int = 3,  # Increased retry count for better reliability
+        repetition_penalty: float = 1.15  # Increased repetition penalty for better quality
     ) -> AsyncGenerator[str, None]:
         """Stream text generation with token-level streaming and robust error handling"""
         # Update activity timestamp
         self._update_activity()
 
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         payload = {
             "prompt": prompt,
             "model_id": model_id,
             "stream": True,
             "max_length": max_length,
             "temperature": temperature,
-            "top_p": top_p
+            "top_p": top_p,
+            # Add repetition_penalty for better quality
+            "repetition_penalty": 1.1
         }
 
         # Create a timeout for this specific request
@@ -313,6 +320,7 @@ async def stream_generate(
         # Track retries
         retries = 0
         last_error = None
+        accumulated_text = ""  # Track accumulated text for error recovery
 
         while retries <= retry_count:
             try:
@@ -329,12 +337,15 @@ async def stream_generate(
                     # Track if we've seen any data to detect early disconnections
                     received_data = False
                     # Buffer for accumulating partial responses if needed
+                    token_buffer = ""
+                    last_token_time = time.time()
 
                     try:
                         # Process the streaming response
                         async for line in response.content:
                             if line:
                                 received_data = True
+                                current_time = time.time()
                                 text = line.decode("utf-8").strip()
 
                                 # Skip empty lines
@@ -347,18 +358,30 @@ async def stream_generate(
 
                                 # Check for end of stream marker
                                 if text == "[DONE]":
+                                    # If we have any buffered text, yield it before ending
+                                    if token_buffer:
+                                        yield token_buffer
                                     break
 
                                 # Check for error messages
                                 if text.startswith("\nError:") or text.startswith("Error:"):
                                     error_msg = text.replace("\nError: ", "").replace("Error: ", "")
                                     raise Exception(error_msg)
 
+                                # Add to accumulated text for error recovery
+                                accumulated_text += text
+
+                                # Reset the last token time
+                                last_token_time = current_time
+
+                                # Yield the token directly for immediate feedback
                                 yield text
 
                     # If we didn't receive any data, the stream might have ended unexpectedly
                     if not received_data:
-                        yield "\nError: Stream ended unexpectedly without returning any data"
+                        # If we have accumulated text from previous attempts, don't report an error
+                        if not accumulated_text:
+                            yield "\nError: Stream ended unexpectedly without returning any data"
 
                     # Successful completion, break the retry loop
                     break

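For context, the sketch below shows how a caller might consume the updated async streaming API. It is a minimal illustration and not part of this commit: the import path, server URL, and prompt are assumptions based on the package layout shown above, and it simply relies on the new defaults introduced here (300-second timeout, retry_count=3, a repetition penalty, and max_length falling back to 4096 when left as None).

import asyncio
from locallab_client import LocalLabClient  # import path assumed from the package directory name

async def main():
    # Base URL is illustrative; point it at a running LocalLab server.
    client = LocalLabClient("http://localhost:8000")
    await client.connect()
    try:
        # stream_generate now retries up to 3 times, waits up to 300 s,
        # sends a repetition penalty, and defaults max_length to 4096.
        async for token in client.stream_generate("Explain retry logic in one paragraph"):
            print(token, end="", flush=True)
    finally:
        await client.close()  # assumes the client exposes a close() coroutine

asyncio.run(main())
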
client/python_client/locallab_client/sync_client.py

Lines changed: 95 additions & 25 deletions
@@ -73,7 +73,7 @@ def _ensure_connection(self):
     def _run_coroutine(self, coro, timeout: Optional[float] = None):
         """Run a coroutine in the event loop thread with timeout and error handling."""
         self._ensure_connection()
-        
+
         try:
             future = asyncio.run_coroutine_threadsafe(coro, self._loop)
             return future.result(timeout=timeout)
@@ -123,10 +123,10 @@ def close(self):
             # Clean up
             self._loop = None
             self._thread = None
-            
+
             # Shutdown executor
             self._executor.shutdown(wait=False)
-            
+
         except Exception as e:
             logger.error(f"Error during client cleanup: {str(e)}")
         finally:
@@ -139,25 +139,40 @@ def generate(
         stream: bool = False,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Union[str, Generator[str, None, None]]:
         """
-        Generate text using the model.
+        Generate text using the model with improved quality settings.
 
         Args:
             prompt: The prompt to generate text from
             model_id: Optional model ID to use
             stream: Whether to stream the response
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 1024 if None)
             temperature: Temperature for sampling
             top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             If stream=False, returns the generated text as a string.
             If stream=True, returns a generator that yields chunks of text.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         if stream:
-            return self.stream_generate(prompt, model_id, max_length, temperature, top_p)
+            return self.stream_generate(
+                prompt=prompt,
+                model_id=model_id,
+                max_length=max_length,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
+            )
 
         return self._run_coroutine(
             self._async_client.generate(
@@ -166,7 +181,10 @@ def generate(
                 stream=False,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k,
+                timeout=180.0  # Increased timeout for more complete responses (3 minutes)
             )
         )
 
@@ -177,22 +195,29 @@ def stream_generate(
         max_length: Optional[int] = None,
         temperature: float = 0.7,
         top_p: float = 0.9,
-        timeout: float = 60.0
+        timeout: float = 300.0,  # Increased timeout for more complete responses (5 minutes)
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Generator[str, None, None]:
         """
-        Stream text generation.
+        Stream text generation with improved quality and reliability.
 
         Args:
            prompt: The prompt to generate text from
            model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 1024 if None)
            temperature: Temperature for sampling
            top_p: Top-p for nucleus sampling
            timeout: Request timeout in seconds
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             A generator that yields chunks of text as they are generated.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         # Create a queue to pass data between the async and sync worlds
         queue = asyncio.Queue()
         stop_event = threading.Event()
@@ -206,7 +231,10 @@ async def producer():
                     max_length=max_length,
                     temperature=temperature,
                     top_p=top_p,
-                    timeout=timeout
+                    timeout=timeout,
+                    retry_count=3,  # Increased retry count for better reliability
+                    repetition_penalty=repetition_penalty,  # Pass the repetition penalty parameter
+                    top_k=top_k  # Pass the top_k parameter
                 ):
                     await queue.put(chunk)
 
@@ -250,25 +278,41 @@ def chat(
         stream: bool = False,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
         """
-        Chat completion.
+        Chat completion with improved quality settings.
 
         Args:
             messages: List of message dictionaries with 'role' and 'content' keys
             model_id: Optional model ID to use
            stream: Whether to stream the response
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 1024 if None)
            temperature: Temperature for sampling
            top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             If stream=False, returns the chat completion response.
             If stream=True, returns a generator that yields chunks of the response.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         if stream:
-            return self.stream_chat(messages, model_id, max_length, temperature, top_p)
+            return self.stream_chat(
+                messages=messages,
+                model_id=model_id,
+                max_length=max_length,
+                temperature=temperature,
+                top_p=top_p,
+                timeout=300.0,  # Increased timeout for more complete responses (5 minutes)
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
+            )
 
         return self._run_coroutine(
             self._async_client.chat(
@@ -277,7 +321,10 @@ def chat(
                 stream=False,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                timeout=180.0,  # Increased timeout for more complete responses (3 minutes)
+                repetition_penalty=repetition_penalty,
+                top_k=top_k
             )
         )
 
@@ -287,21 +334,29 @@ def stream_chat(
         model_id: Optional[str] = None,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        timeout: float = 300.0,  # Increased timeout for more complete responses (5 minutes)
+        repetition_penalty: float = 1.15,  # Added repetition penalty for better quality
+        top_k: int = 80  # Added top_k parameter for better quality
     ) -> Generator[Dict[str, Any], None, None]:
         """
-        Stream chat completion.
+        Stream chat completion with improved quality and reliability.
 
         Args:
            messages: List of message dictionaries with 'role' and 'content' keys
            model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 1024 if None)
            temperature: Temperature for sampling
            top_p: Top-p for nucleus sampling
+            timeout: Request timeout in seconds
 
         Returns:
             A generator that yields chunks of the chat completion response.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         # Create a queue to pass data between the async and sync worlds
         queue = asyncio.Queue()
         stop_event = threading.Event()
@@ -314,7 +369,11 @@ async def producer():
                     model_id=model_id,
                     max_length=max_length,
                     temperature=temperature,
-                    top_p=top_p
+                    top_p=top_p,
+                    timeout=timeout,
+                    retry_count=3,  # Increased retry count for better reliability
+                    repetition_penalty=repetition_penalty,
+                    top_k=top_k
                 ):
                     await queue.put(chunk)
 
@@ -357,28 +416,39 @@ def batch_generate(
         model_id: Optional[str] = None,
         max_length: Optional[int] = None,
         temperature: float = 0.7,
-        top_p: float = 0.9
+        top_p: float = 0.9,
+        repetition_penalty: float = 1.15,  # Increased repetition penalty for better quality
+        top_k: int = 80,  # Added top_k parameter for better quality
+        timeout: float = 300.0  # Added timeout parameter (5 minutes)
     ) -> Dict[str, List[str]]:
         """
-        Generate text for multiple prompts in parallel.
+        Generate text for multiple prompts in parallel with improved quality settings.
 
         Args:
            prompts: List of prompts to generate text from
            model_id: Optional model ID to use
-            max_length: Maximum length of the generated text
+            max_length: Maximum length of the generated text (defaults to 1024 if None)
            temperature: Temperature for sampling
            top_p: Top-p for nucleus sampling
+            repetition_penalty: Penalty for repetition (higher values = less repetition)
 
         Returns:
             Dictionary with the generated responses.
         """
+        # Use a higher max_length by default to ensure complete responses
+        if max_length is None:
+            max_length = 4096  # Default to 4096 tokens for more complete responses
+
         return self._run_coroutine(
             self._async_client.batch_generate(
                 prompts=prompts,
                 model_id=model_id,
                 max_length=max_length,
                 temperature=temperature,
-                top_p=top_p
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                top_k=top_k,
+                timeout=timeout  # Use the provided timeout parameter
             )
         )

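The synchronous wrapper forwards the same knobs to the async client, so the sketch below is a rough equivalent for blocking callers. It is illustrative only: the exported class name (SyncLocalLabClient here) and the base URL are assumptions, since neither appears in this diff; the keyword arguments match the signatures added above.

from locallab_client import SyncLocalLabClient  # class name assumed; not shown in this diff

client = SyncLocalLabClient("http://localhost:8000")  # base URL is illustrative
try:
    # Blocking generation; max_length=None now falls back to 4096 inside the client,
    # and repetition_penalty/top_k default to the new quality-oriented values.
    text = client.generate(
        "List three ways a repetition penalty changes sampling",
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.15,
        top_k=80,
    )
    print(text)

    # Streaming chat now uses a 300 s timeout and retry_count=3 under the hood.
    messages = [{"role": "user", "content": "Summarize the commit above in one line."}]
    for chunk in client.stream_chat(messages):
        print(chunk)  # each chunk is a dict; its exact shape is not shown in this diff
finally:
    client.close()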