Commit 9530cac

MeluXina user authored and committed
contextplus opt-in, mockOpenAI improvements, improved retry logic
1 parent 9198ffb commit 9530cac

3 files changed: 41 additions & 17 deletions


mallm/models/Chat.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import LLMResult
 from langchain_core.prompt_values import PromptValue
-from openai import APIError, RateLimitError, OpenAI
+from openai import APIError, APIConnectionError, RateLimitError, OpenAI

 class Chat(LLM):  # type: ignore
     """A custom chat model that queries the chat API of HuggingFace Text Generation Inference
@@ -132,7 +132,7 @@ def _call(  # type: ignore
                 collected_messages.append(message_str)
                 log_prob_sum = log_prob_sum / len(collected_messages)
                 break
-            except APIError as e:
+            except (APIError, APIConnectionError, RateLimitError) as e:
                 # Handle API error here, e.g. retry or log
                 retries += 1
                 if retries < 5:
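Catching APIConnectionError and RateLimitError explicitly makes the retry loop robust to transient connection drops and rate limits even on client versions where those are not APIError subclasses. Below is a minimal sketch of the retry pattern this hunk touches, assuming an OpenAI-compatible endpoint; MAX_RETRIES, the backoff schedule, and the endpoint URL are illustrative and not taken from the repository:

import time

from openai import OpenAI, APIError, APIConnectionError, RateLimitError

MAX_RETRIES = 5  # illustrative; the hunk above only shows "if retries < 5"
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")  # assumed TGI-style endpoint

def query_with_retries(messages: list[dict[str, str]]) -> str:
    retries = 0
    while True:
        try:
            response = client.chat.completions.create(model="tgi", messages=messages)
            return response.choices[0].message.content or ""
        except (APIError, APIConnectionError, RateLimitError):
            # Connection drops and rate limits are retried instead of crashing the run.
            retries += 1
            if retries >= MAX_RETRIES:
                raise
            time.sleep(2 ** retries)  # simple exponential backoff (illustrative)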

mallm/models/MockOpenAI.py

Lines changed: 24 additions & 7 deletions
@@ -64,13 +64,30 @@ def _extract_final_solution_from_messages(messages: list[dict[str, str]]) -> str

     # Extraction prompts
     if "extract the final solution" in last.lower():
-        # Try to extract from "Your previous response:" if present
-        m = re.search(r"previous response:\s*(.*)$", last, re.IGNORECASE | re.DOTALL)
-        if m:
-            # Return the previous response as-is (test-friendly)
-            return m.group(1).strip()
-        # Fallback
-        return "Final Solution"
+        # Prefer extracting from the message that actually contains "Your previous response:"
+        src = ""
+        for msg in reversed(messages):
+            content = msg.get("content", "") or ""
+            if re.search(r"\bprevious response:\b", content, re.IGNORECASE):
+                src = content
+                break
+        if not src:
+            src = last
+
+        m = re.search(r"previous response:\s*(.*)$", src, re.IGNORECASE | re.DOTALL)
+        extracted = (m.group(1) if m else src).strip()
+
+        # If the extraction instruction was merged into the same user message, strip it off.
+        cut = re.search(r"\n\s*extract the final solution\b", extracted, re.IGNORECASE)
+        if cut:
+            extracted = extracted[: cut.start()].strip()
+
+        # If the previous response contains a "Final Solution:" marker, return only the solution part.
+        m2 = re.search(r"final solution:\s*(.*)$", extracted, re.IGNORECASE | re.DOTALL)
+        if m2:
+            return m2.group(1).strip()
+
+        return extracted or "Final Solution"

     # Task-specific simple heuristics
     if "capital of france" in joined.lower():

mallm/scheduler.py

Lines changed: 15 additions & 8 deletions
@@ -22,8 +22,6 @@
 import langchain
 import langchain_core
 import openai
-
-from contextplus import context
 from mallm.models.MockOpenAI import MockOpenAI

 try:
@@ -373,8 +371,9 @@ def manage_discussions(self, client: httpx.Client) -> None:
         all_model = None
         if not str(self.config.endpoint_url).startswith("mock://") and SentenceTransformer is not None:
             try:
-                paraphrase_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
-                all_model = SentenceTransformer("all-MiniLM-L6-v2")
+                # Force CPU to avoid contention with the main LLM GPU server
+                paraphrase_model = SentenceTransformer("paraphrase-MiniLM-L6-v2", device="cpu")
+                all_model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
             except Exception:
                 paraphrase_model = None
                 all_model = None
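Pinning the two MiniLM sentence-transformer models to the CPU keeps their weights out of GPU memory, which would otherwise be contended with the server hosting the main LLM. The following sketch only illustrates the effect of device="cpu"; the similarity computation is invented for illustration and does not reflect how the scheduler consumes the embeddings:

from sentence_transformers import SentenceTransformer, util

# Same two models as in the hunk above, kept on CPU so they do not claim GPU memory.
paraphrase_model = SentenceTransformer("paraphrase-MiniLM-L6-v2", device="cpu")
all_model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")

embeddings = all_model.encode(
    ["The agents reached a consensus.", "The agents agreed on an answer."],
    convert_to_tensor=True,  # tensors stay on the CPU device chosen above
)
print(util.cos_sim(embeddings[0], embeddings[1]).item())  # similarity score in [-1, 1]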
@@ -390,11 +389,19 @@ def worker_paraphrase_function(
             return [[1.0, 0.0] for _ in input_data]

         def worker_context_function(input_data: str) -> str:
-            # Acquire the lock before using the model
-            text: str
+            # Optionally disable heavy context retrieval to avoid GPU memory conflicts.
+            # Default: disabled, unless explicitly enabled via MALLM_ENABLE_CONTEXT=1.
+            if os.environ.get("MALLM_DISABLE_CONTEXT", "0") == "1":
+                return ""
+            if os.environ.get("MALLM_ENABLE_CONTEXT", "0") != "1":
+                return ""
+            # Acquire the lock before using the model; import lazily to avoid loading on module import.
             with context_lock:
-                text = context(input_data)
-                return text
+                try:
+                    from contextplus import context as cp_context  # type: ignore
+                    return cp_context(input_data)
+                except Exception:
+                    return ""

         def worker_persona_diversity_function(
             input_data: list[str],
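Context retrieval through contextplus is now opt-in: it runs only when MALLM_ENABLE_CONTEXT=1 and MALLM_DISABLE_CONTEXT is not set to 1, and the contextplus import is deferred into the worker so the package is not needed unless the feature is actually used. A minimal sketch of the resulting gating, assuming the same reading of the two environment variables as in worker_context_function above:

import os

def context_enabled() -> bool:
    # Mirrors the gating in worker_context_function: an explicit kill switch,
    # plus an explicit opt-in before contextplus is ever imported.
    if os.environ.get("MALLM_DISABLE_CONTEXT", "0") == "1":
        return False
    return os.environ.get("MALLM_ENABLE_CONTEXT", "0") == "1"

print(context_enabled())  # False by default: no contextplus import, no extra model loaded
os.environ["MALLM_ENABLE_CONTEXT"] = "1"
print(context_enabled())  # True: the worker will lazily import contextplus inside context_lock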
