
Commit b1c7d10

Improved AI response generation quality
1 parent ee7a3df commit b1c7d10

2 files changed: 131 additions, 43 deletions

locallab/model_manager.py

Lines changed: 61 additions & 25 deletions
@@ -434,19 +434,27 @@ async def generate(
         from .config import get_model_generation_params
         gen_params = get_model_generation_params(self.current_model)
 
-        # Set balanced defaults for quality and speed
+        # Set optimized defaults for high-quality responses
         if not max_length and not max_new_tokens:
-            # Use a reasonable default max_length that balances quality and speed
+            # Use a higher default max_length for more complete, high-quality responses
             # Don't limit it too much to ensure complete responses
-            gen_params["max_length"] = min(gen_params.get("max_length", DEFAULT_MAX_LENGTH), 1024)
+            gen_params["max_length"] = min(gen_params.get("max_length", DEFAULT_MAX_LENGTH), 4096)
 
         if not temperature:
-            # Use a balanced temperature for good quality responses
+            # Use a balanced temperature for high-quality responses
             gen_params["temperature"] = gen_params.get("temperature", DEFAULT_TEMPERATURE)
 
         if not top_k:
-            # Add top_k for better quality sampling
-            gen_params["top_k"] = 50
+            # Use a higher top_k for better quality sampling
+            gen_params["top_k"] = 80  # Increased from 50 to 80 for better quality
+
+        if not top_p:
+            # Use a higher top_p for better quality
+            gen_params["top_p"] = 0.92  # Increased for better quality
+
+        if not repetition_penalty:
+            # Use a higher repetition_penalty for better quality
+            gen_params["repetition_penalty"] = 1.15  # Increased from 1.1 to 1.15
 
         # Handle max_new_tokens parameter (map to max_length)
         if max_new_tokens is not None:
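
These defaults only kick in when the caller did not pass a value. A minimal standalone sketch of that fallback behaviour (the helper name and the DEFAULT_* placeholder values are assumptions for illustration, not part of locallab):

    # Hypothetical helper mirroring the "only override when unset" pattern above.
    # DEFAULT_* values are placeholders; the real constants live in locallab's config.
    DEFAULT_MAX_LENGTH = 4096
    DEFAULT_TEMPERATURE = 0.7

    def apply_quality_defaults(gen_params, max_length=None, max_new_tokens=None,
                               temperature=None, top_k=None, top_p=None,
                               repetition_penalty=None):
        if not max_length and not max_new_tokens:
            gen_params["max_length"] = min(gen_params.get("max_length", DEFAULT_MAX_LENGTH), 4096)
        if not temperature:
            gen_params["temperature"] = gen_params.get("temperature", DEFAULT_TEMPERATURE)
        if not top_k:
            gen_params["top_k"] = 80
        if not top_p:
            gen_params["top_p"] = 0.92
        if not repetition_penalty:
            gen_params["repetition_penalty"] = 1.15
        return gen_params

    # Values the caller passes explicitly are left alone; everything unset
    # falls back to the new defaults introduced in this commit.
    params = apply_quality_defaults({})
    assert params["top_k"] == 80 and params["top_p"] == 0.92 and params["repetition_penalty"] == 1.15
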
@@ -503,31 +511,50 @@ async def generate(
 
         with torch.no_grad():
             try:
+                # Generate parameters optimized for high-quality responses
                 generate_params = {
                     **inputs,
                     "max_new_tokens": gen_params["max_length"],
                     "temperature": gen_params["temperature"],
                     "top_p": gen_params["top_p"],
+                    "top_k": gen_params.get("top_k", 80),  # Default to 80 for better quality
                     "do_sample": gen_params.get("do_sample", True),
                     "pad_token_id": self.tokenizer.eos_token_id,
                     # Fix the early stopping warning by setting num_beams explicitly
                     "num_beams": 1,
-                    # Add repetition penalty by default for better quality
-                    "repetition_penalty": 1.1
+                    # Add repetition penalty for better quality
+                    "repetition_penalty": gen_params.get("repetition_penalty", 1.15)  # Increased from 1.1 to 1.15
                 }
 
-                # Add optional parameters if present in gen_params
-                if "top_k" in gen_params:
-                    generate_params["top_k"] = gen_params["top_k"]
-                if "repetition_penalty" in gen_params:
-                    generate_params["repetition_penalty"] = gen_params["repetition_penalty"]
-
                 # Set a reasonable max time for generation to prevent hanging
-                # Use the DEFAULT_MAX_TIME from config (120 seconds)
+                # Use the DEFAULT_MAX_TIME from config (increased to 180 seconds)
                 if "max_time" not in generate_params and not stream:
                     from .config import DEFAULT_MAX_TIME
                     generate_params["max_time"] = DEFAULT_MAX_TIME  # Use the default max time from config
 
+                # Define comprehensive stop sequences for proper termination
+                stop_sequences = [
+                    "</s>", "<|endoftext|>", "<|im_end|>",
+                    "<eos>", "<end>", "<|end|>", "<|EOS|>",
+                    "###", "Assistant:", "Human:", "User:"
+                ]
+
+                # Add stop sequences to generation parameters if supported by the model
+                if hasattr(self.model.config, "stop_token_ids") or hasattr(self.model.generation_config, "stopping_criteria"):
+                    # Convert stop sequences to token IDs
+                    stop_token_ids = []
+                    for seq in stop_sequences:
+                        try:
+                            ids = self.tokenizer.encode(seq, add_special_tokens=False)
+                            if ids:
+                                stop_token_ids.extend(ids)
+                        except:
+                            pass
+
+                    # Add stop token IDs to generation parameters if supported
+                    if hasattr(self.model.config, "stop_token_ids"):
+                        self.model.config.stop_token_ids = stop_token_ids
+
                 # Use efficient attention implementation if available
                 if hasattr(self.model.config, "attn_implementation"):
                     generate_params["attn_implementation"] = "flash_attention_2"
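
Note that this hunk stores the collected stop-token IDs on `model.config` rather than passing them to `generate()`, so whether they actually terminate generation depends on the model's own generation config. A hedged alternative sketch, using Hugging Face's `StoppingCriteria` interface (not code from this commit), would enforce the same stop strings explicitly:

    # Sketch only: enforce the stop strings via transformers' StoppingCriteria,
    # instead of relying on model.config.stop_token_ids being honoured.
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnSequences(StoppingCriteria):
        def __init__(self, tokenizer, stop_sequences, prompt_len):
            self.tokenizer = tokenizer
            self.stop_sequences = stop_sequences
            self.prompt_len = prompt_len  # number of prompt tokens to skip when decoding

        def __call__(self, input_ids, scores, **kwargs) -> bool:
            # Decode only the newly generated tail and stop once any stop string appears
            new_text = self.tokenizer.decode(input_ids[0][self.prompt_len:], skip_special_tokens=False)
            return any(seq in new_text for seq in self.stop_sequences)

    # Usage sketch inside generate(): prompt_len comes from the tokenized inputs.
    # generate_params["stopping_criteria"] = StoppingCriteriaList(
    #     [StopOnSequences(self.tokenizer, stop_sequences, inputs["input_ids"].shape[1])]
    # )
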
@@ -568,16 +595,25 @@ async def generate(
                         response = response[:marker_pos]
                         break
 
-                # Additional cleanup for any remaining special tokens
-                special_tokens = ["<|", "|>"]
-                for token in special_tokens:
-                    if token in response:
-                        # Check if it's part of a special token pattern
-                        pattern = r'<\|[a-zA-Z0-9_]+\|>'
-                        matches = re.finditer(pattern, response)
-                        for match in matches:
-                            # Replace the special token with empty string
-                            response = response.replace(match.group(0), "")
+                # Additional cleanup for any remaining special tokens using regex
+                special_token_pattern = r'<\|[a-zA-Z0-9_]+\|>'
+                response = re.sub(special_token_pattern, '', response)
+
+                # Check for repetition patterns that indicate the model is stuck
+                if len(response) > 200:
+                    # Look for repeating patterns of 20+ characters that repeat 3+ times
+                    for pattern_len in range(20, 40):
+                        if pattern_len < len(response) // 3:
+                            for i in range(len(response) - pattern_len * 3):
+                                pattern = response[i:i+pattern_len]
+                                if pattern and not pattern.isspace():
+                                    if response[i:].count(pattern) >= 3:
+                                        # Found a repeating pattern, truncate at the second occurrence
+                                        second_pos = response.find(pattern, i + pattern_len)
+                                        if second_pos > 0:
+                                            logger.info(f"Detected repetition pattern, truncating response")
+                                            response = response[:second_pos + pattern_len]
+                                            break
 
                 # Cache the cleaned response if we have a cache key
                 if cache_key:
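
The repetition check above is duplicated verbatim in three route handlers below. As a hedged refactoring sketch (the helper name and module placement are hypothetical, not part of the commit), the same logic can live in one reusable function:

    import logging

    logger = logging.getLogger(__name__)

    def truncate_repetition(text, min_len=20, max_len=40, repeats=3):
        """Cut `text` just past the second occurrence of any span of
        min_len..max_len-1 characters that repeats `repeats` or more times."""
        if len(text) <= 200:
            return text
        for pattern_len in range(min_len, max_len):
            if pattern_len >= len(text) // repeats:
                continue
            for i in range(len(text) - pattern_len * repeats):
                pattern = text[i:i + pattern_len]
                if pattern and not pattern.isspace() and text[i:].count(pattern) >= repeats:
                    second_pos = text.find(pattern, i + pattern_len)
                    if second_pos > 0:
                        logger.info("Detected repetition pattern, truncating response")
                        return text[:second_pos + pattern_len]
        return text

    # Example: a response that loops on one sentence is cut shortly after
    # the second repetition starts.
    looping = "Here is the answer. " + "The same sentence repeats endlessly. " * 8
    print(truncate_repetition(looping))
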

locallab/routes/generate.py

Lines changed: 70 additions & 18 deletions
@@ -137,28 +137,29 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
     # Get model-specific generation parameters
     model_params = get_model_generation_params(model_manager.current_model)
 
-    # Update with request parameters
+    # Update with request parameters and optimized defaults for high-quality responses
     generation_params = {
         "max_new_tokens": request.max_tokens,
         "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k,
-        "repetition_penalty": request.repetition_penalty,
+        "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
+        "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
+        "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
         "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
     }
 
     # Merge model-specific params with request params
+    # This ensures we get the best of both worlds - model-specific optimizations
+    # and our high-quality parameters
     generation_params.update(model_params)
 
-    # Generate text - properly await the async call
+    # Generate text with optimized parameters - properly await the async call
     generated_text = await model_manager.generate_text(
         prompt=request.prompt,
         system_prompt=request.system_prompt,
         **generation_params
     )
 
     # Additional cleanup for any special tokens that might have slipped through
-    import re
     special_token_pattern = r'<\|[a-zA-Z0-9_]+\|>'
     cleaned_text = re.sub(special_token_pattern, '', generated_text)
 
@@ -172,6 +173,22 @@ async def generate_text(request: GenerationRequest) -> GenerationResponse:
             cleaned_text = cleaned_text[:marker_pos]
             break
 
+    # Check for repetition patterns that indicate the model is stuck
+    if len(cleaned_text) > 200:
+        # Look for repeating patterns of 20+ characters that repeat 3+ times
+        for pattern_len in range(20, 40):
+            if pattern_len < len(cleaned_text) // 3:
+                for i in range(len(cleaned_text) - pattern_len * 3):
+                    pattern = cleaned_text[i:i+pattern_len]
+                    if pattern and not pattern.isspace():
+                        if cleaned_text[i:].count(pattern) >= 3:
+                            # Found a repeating pattern, truncate at the second occurrence
+                            second_pos = cleaned_text.find(pattern, i + pattern_len)
+                            if second_pos > 0:
+                                logger.info(f"Detected repetition pattern in text generation, truncating response")
+                                cleaned_text = cleaned_text[:second_pos + pattern_len]
+                                break
+
     return GenerationResponse(
         text=cleaned_text,
         model=model_manager.current_model
@@ -203,27 +220,28 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
     # Get model-specific generation parameters
     model_params = get_model_generation_params(model_manager.current_model)
 
-    # Prepare generation parameters
+    # Prepare generation parameters with optimized defaults for high-quality responses
     generation_params = {
         "max_new_tokens": request.max_tokens,
         "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k,
-        "repetition_penalty": request.repetition_penalty,
+        "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
+        "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
+        "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
         "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
     }
 
     # Merge model-specific params with request params
+    # This ensures we get the best of both worlds - model-specific optimizations
+    # and our high-quality parameters
     generation_params.update(model_params)
 
-    # Generate completion
+    # Generate completion with optimized parameters
     generated_text = await model_manager.generate_text(
         prompt=formatted_prompt,
         **generation_params
     )
 
     # Additional cleanup for any special tokens that might have slipped through
-    import re
     special_token_pattern = r'<\|[a-zA-Z0-9_]+\|>'
     cleaned_text = re.sub(special_token_pattern, '', generated_text)
 
@@ -237,6 +255,22 @@ async def chat_completion(request: ChatRequest) -> ChatResponse:
             cleaned_text = cleaned_text[:marker_pos]
             break
 
+    # Check for repetition patterns that indicate the model is stuck
+    if len(cleaned_text) > 200:
+        # Look for repeating patterns of 20+ characters that repeat 3+ times
+        for pattern_len in range(20, 40):
+            if pattern_len < len(cleaned_text) // 3:
+                for i in range(len(cleaned_text) - pattern_len * 3):
+                    pattern = cleaned_text[i:i+pattern_len]
+                    if pattern and not pattern.isspace():
+                        if cleaned_text[i:].count(pattern) >= 3:
+                            # Found a repeating pattern, truncate at the second occurrence
+                            second_pos = cleaned_text.find(pattern, i + pattern_len)
+                            if second_pos > 0:
+                                logger.info(f"Detected repetition pattern in chat completion, truncating response")
+                                cleaned_text = cleaned_text[:second_pos + pattern_len]
+                                break
+
     # Format response with cleaned text
     return ChatResponse(
         choices=[{
@@ -388,7 +422,7 @@ async def stream_chat(
 @router.post("/generate/batch", response_model=BatchGenerationResponse)
 async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResponse:
     """
-    Generate text for multiple prompts in a single request
+    Generate high-quality text for multiple prompts in a single request
     """
     if not model_manager.current_model:
         raise HTTPException(status_code=400, detail="No model is currently loaded")
@@ -397,29 +431,31 @@ async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResp
     # Get model-specific generation parameters
     model_params = get_model_generation_params(model_manager.current_model)
 
-    # Update with request parameters
+    # Update with request parameters and optimized defaults for high-quality responses
     generation_params = {
         "max_new_tokens": request.max_tokens,
         "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k,
-        "repetition_penalty": request.repetition_penalty,
+        "top_p": request.top_p if request.top_p is not None else 0.92,  # Optimized default
+        "top_k": request.top_k if request.top_k is not None else 80,  # Optimized default
+        "repetition_penalty": request.repetition_penalty if request.repetition_penalty is not None else 1.15,  # Optimized default
         "do_sample": model_params.get("do_sample", True)  # Pass do_sample from model params
     }
 
     # Merge model-specific params with request params
+    # This ensures we get the best of both worlds - model-specific optimizations
+    # and our high-quality parameters
     generation_params.update(model_params)
 
     responses = []
     for prompt in request.prompts:
+        # Generate text with optimized parameters
         generated_text = await model_manager.generate_text(
             prompt=prompt,
             system_prompt=request.system_prompt,
             **generation_params
         )
 
         # Additional cleanup for any special tokens that might have slipped through
-        import re
         special_token_pattern = r'<\|[a-zA-Z0-9_]+\|>'
         cleaned_text = re.sub(special_token_pattern, '', generated_text)
 
@@ -433,6 +469,22 @@ async def batch_generate(request: BatchGenerationRequest) -> BatchGenerationResp
                 cleaned_text = cleaned_text[:marker_pos]
                 break
 
+        # Check for repetition patterns that indicate the model is stuck
+        if len(cleaned_text) > 200:
+            # Look for repeating patterns of 20+ characters that repeat 3+ times
+            for pattern_len in range(20, 40):
+                if pattern_len < len(cleaned_text) // 3:
+                    for i in range(len(cleaned_text) - pattern_len * 3):
+                        pattern = cleaned_text[i:i+pattern_len]
+                        if pattern and not pattern.isspace():
+                            if cleaned_text[i:].count(pattern) >= 3:
+                                # Found a repeating pattern, truncate at the second occurrence
+                                second_pos = cleaned_text.find(pattern, i + pattern_len)
+                                if second_pos > 0:
+                                    logger.info(f"Detected repetition pattern in batch generation, truncating response")
+                                    cleaned_text = cleaned_text[:second_pos + pattern_len]
+                                    break
+
         responses.append(cleaned_text)
 
     return BatchGenerationResponse(responses=responses)
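
With these changes, a request that omits the sampling fields is meant to pick up the new defaults (top_p 0.92, top_k 80, repetition_penalty 1.15). A hedged usage sketch; the host, port, route path, and response field names are assumptions, not confirmed by this diff:

    # Hedged usage sketch: URL and response shape are assumed for illustration.
    import requests

    resp = requests.post(
        "http://localhost:8000/generate",
        json={
            "prompt": "Explain what a context manager does in Python.",
            "max_tokens": 512,
            # top_p, top_k and repetition_penalty are omitted on purpose:
            # the route is meant to fall back to 0.92 / 80 / 1.15.
        },
        timeout=300,
    )
    print(resp.json()["text"])
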
