Skip to content

Commit ee7a3df

Browse files
committed
Improved Streaming Generation Response Quality
1 parent 7be9307 commit ee7a3df

File tree

3 files changed

+262
-92
lines changed

3 files changed

+262
-92
lines changed

locallab/config.py

Lines changed: 19 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -93,16 +93,16 @@ def save_config(config: Dict[str, Any]):
93 93
ENABLE_CORS = get_env_var("ENABLE_CORS", default="true", var_type=bool)
94 94
CORS_ORIGINS = get_env_var("CORS_ORIGINS", default="*").split(",")
95 95

96-
# Model settings
96+
# Model settings - optimized for high-quality responses
9797
DEFAULT_MODEL = get_env_var("DEFAULT_MODEL", default="microsoft/phi-2")
9898
DEFAULT_MAX_LENGTH = get_env_var(
99-
"DEFAULT_MAX_LENGTH", default=8192, var_type=int) # Increased from 4096 to 8192 for more complete responses
99+
"DEFAULT_MAX_LENGTH", default=8192, var_type=int) # Using 8192 for complete, high-quality responses
100100
DEFAULT_TEMPERATURE = get_env_var(
101-
"DEFAULT_TEMPERATURE", default=0.7, var_type=float)
102-
DEFAULT_TOP_P = get_env_var("DEFAULT_TOP_P", default=0.9, var_type=float)
103-
DEFAULT_TOP_K = get_env_var("DEFAULT_TOP_K", default=80, var_type=int) # Increased from 50 to 80 for better quality
104-
DEFAULT_REPETITION_PENALTY = get_env_var("DEFAULT_REPETITION_PENALTY", default=1.15, var_type=float) # Increased from 1.1 to 1.15
105-
DEFAULT_MAX_TIME = get_env_var("DEFAULT_MAX_TIME", default=120.0, var_type=float) # Added default max_time of 120 seconds
101+
"DEFAULT_TEMPERATURE", default=0.7, var_type=float) # Balanced temperature for quality and creativity
102+
DEFAULT_TOP_P = get_env_var("DEFAULT_TOP_P", default=0.92, var_type=float) # Increased from 0.9 to 0.92 for better quality
103+
DEFAULT_TOP_K = get_env_var("DEFAULT_TOP_K", default=80, var_type=int) # Using 80 for better quality responses
104+
DEFAULT_REPETITION_PENALTY = get_env_var("DEFAULT_REPETITION_PENALTY", default=1.15, var_type=float) # Using 1.15 to prevent repetition while allowing natural patterns
105+
DEFAULT_MAX_TIME = get_env_var("DEFAULT_MAX_TIME", default=180.0, var_type=float) # Increased from 120 to 180 seconds for more complete responses
106106

107107
# Optimization settings
108108
ENABLE_QUANTIZATION = get_env_var(
@@ -466,23 +466,26 @@ def estimate_model_requirements(model_id: str) -> Optional[Dict[str, Any]]:
466466
def get_model_generation_params(model_id: Optional[str] = None) -> dict:
467467
"""Get model generation parameters, optionally specific to a model.
468468
469-
This function prioritizes quality and completeness of responses by using
470-
higher max_length and appropriate repetition_penalty values.
469+
This function prioritizes high-quality, complete responses by using
470+
optimized parameters for temperature, top_p, top_k, repetition_penalty,
471+
and max_length.
471472
472473
Args:
473474
model_id: Optional model ID to get specific parameters for
474475
475476
Returns:
476-
Dictionary of generation parameters optimized for complete responses
477+
Dictionary of generation parameters optimized for high-quality responses
477478
"""
478-
# Base parameters (defaults) - optimized for quality and completeness
479+
# Base parameters (defaults) - optimized for high-quality responses
479480
params = {
480481
"max_length": get_env_var("LOCALLAB_MODEL_MAX_LENGTH", default=DEFAULT_MAX_LENGTH, var_type=int),
481482
"temperature": get_env_var("LOCALLAB_MODEL_TEMPERATURE", default=DEFAULT_TEMPERATURE, var_type=float),
482483
"top_p": get_env_var("LOCALLAB_MODEL_TOP_P", default=DEFAULT_TOP_P, var_type=float),
483484
"top_k": get_env_var("LOCALLAB_TOP_K", default=DEFAULT_TOP_K, var_type=int),
484485
"repetition_penalty": get_env_var("LOCALLAB_REPETITION_PENALTY", default=DEFAULT_REPETITION_PENALTY, var_type=float),
485-
# Add do_sample parameter to ensure proper sampling
486+
# Add max_time parameter to ensure responses have enough time to complete
487+
"max_time": get_env_var("LOCALLAB_MAX_TIME", default=DEFAULT_MAX_TIME, var_type=float),
488+
# Add do_sample parameter to ensure proper sampling for high-quality responses
486489
"do_sample": True
487490
}
488491

@@ -491,11 +494,12 @@ def get_model_generation_params(model_id: Optional[str] = None) -> dict:
491494
model_config = MODEL_REGISTRY[model_id]
492495
# Override with model-specific parameters if available
493496
if "max_length" in model_config:
494-
# Ensure max_length is at least 1024 for complete responses
495-
params["max_length"] = max(model_config["max_length"], 1024)
497+
# Ensure max_length is at least 2048 for high-quality, complete responses
498+
# Increased from 1024 to 2048 for better quality
499+
params["max_length"] = max(model_config["max_length"], 2048)
496500
else:
497501
# If no model-specific max_length, use a reasonable default
498-
params["max_length"] = max(DEFAULT_MAX_LENGTH, 1024)
502+
params["max_length"] = max(DEFAULT_MAX_LENGTH, 2048) # Increased from 1024 to 2048
499503

500504
# Add any other model-specific parameters from the registry
501505
for param in ["temperature", "top_p", "top_k", "repetition_penalty"]:

0 commit comments

Comments (0)