Skip to content

Commit bace08b

Browse files
committed
refactor!: Replace Pydantic model for LLM config with plain dict
This change replaces the Pydantic model used for the LLM configuration (`llm`) with a plain dictionary, improving flexibility by exposing all parameters directly to the deployer. While it would be preferable to annotate `llm` as `langchain_openai.ChatOpenAI`, there is a limitation: Pydantic raises an error ("cannot pickle '_thread.RLock' object"), likely due to incompatibility within `langchain_openai.ChatOpenAI`. Using a plain dictionary removes validation at startup, but the default empty dict is sufficient to create a `ChatOpenAI` instance, making this an acceptable tradeoff for now.
1 parent 3610817 commit bace08b

File tree

4 files changed

+51
-26
lines changed

4 files changed

+51
-26
lines changed

api/chatbot/config.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
import re
2+
from typing import Any
23

3-
from pydantic import BaseModel, HttpUrl, PostgresDsn
4+
from pydantic import Field, PostgresDsn
45
from pydantic_settings import BaseSettings, SettingsConfigDict
56

67

7-
class LLMServiceSettings(BaseModel):
8-
url: HttpUrl = "http://localhost:8080"
9-
"""llm service url"""
10-
model: str = "cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser"
11-
creds: str = "EMPTY"
12-
13-
148
def remove_postgresql_variants(dsn: str) -> str:
159
"""Remove the 'driver' part from a connection string, if one is present in the URI scheme.
1610
@@ -34,7 +28,7 @@ def remove_postgresql_variants(dsn: str) -> str:
3428
class Settings(BaseSettings):
3529
model_config = SettingsConfigDict(env_nested_delimiter="__")
3630

37-
llm: LLMServiceSettings = LLMServiceSettings()
31+
llm: dict[str, Any] = Field(default_factory=dict)
3832
db_url: PostgresDsn = "postgresql+psycopg://postgres:postgres@localhost:5432/"
3933
"""Database url. Must be a valid postgresql connection string."""
4034

api/chatbot/state.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,4 @@
1919
autoflush=False,
2020
class_=AsyncSession,
2121
)
22-
chat_model = ChatOpenAI(
23-
openai_api_base=str(settings.llm.url),
24-
model=settings.llm.model,
25-
openai_api_key=settings.llm.creds,
26-
max_tokens=1024,
27-
streaming=True,
28-
)
22+
chat_model = ChatOpenAI(**settings.llm)

api/tests/test_config.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,57 @@
1-
import os
21
import unittest
3-
from unittest.mock import patch
42

5-
from chatbot.config import Settings
3+
from pydantic import ValidationError
4+
5+
from chatbot.config import Settings, remove_postgresql_variants
6+
7+
8+
class TestRemovePostgresqlVariants(unittest.TestCase):
9+
10+
def test_remove_psycopg(self):
11+
dsn = "postgresql+psycopg://user:pass@localhost/dbname"
12+
expected = "postgresql://user:pass@localhost/dbname"
13+
self.assertEqual(remove_postgresql_variants(dsn), expected)
14+
15+
def test_remove_psycopg2(self):
16+
dsn = "postgresql+psycopg2://user:pass@localhost/dbname"
17+
expected = "postgresql://user:pass@localhost/dbname"
18+
self.assertEqual(remove_postgresql_variants(dsn), expected)
19+
20+
def test_remove_psycopg2cffi(self):
21+
dsn = "postgresql+psycopg2cffi://user:pass@localhost/dbname"
22+
expected = "postgresql://user:pass@localhost/dbname"
23+
self.assertEqual(remove_postgresql_variants(dsn), expected)
24+
25+
def test_no_change(self):
26+
dsn = "postgresql://user:pass@localhost/dbname"
27+
expected = "postgresql://user:pass@localhost/dbname"
28+
self.assertEqual(remove_postgresql_variants(dsn), expected)
629

730

831
class TestSettings(unittest.TestCase):
9-
def test_default_inferece_url(self):
32+
def test_llm_default(self):
1033
settings = Settings()
11-
self.assertEqual(str(settings.llm.url), "http://localhost:8080")
34+
self.assertEqual(settings.llm, {})
35+
36+
def test_llm_custom(self):
37+
custom_llm = {"model": "gpt-3", "version": "davinci"}
38+
settings = Settings(llm=custom_llm)
39+
self.assertEqual(settings.llm, custom_llm)
1240

13-
@patch.dict(os.environ, {"LLM__URL": "http://foo.bar.com"}, clear=True)
14-
def test_inference_server_url(self):
41+
def test_psycopg_url_default(self):
1542
settings = Settings()
16-
self.assertEqual(str(settings.llm.url), "http://foo.bar.com/")
43+
expected = "postgresql://postgres:postgres@localhost:5432/"
44+
self.assertEqual(settings.psycopg_url, expected)
45+
46+
def test_psycopg_url_custom(self):
47+
custom_url = "postgresql+psycopg2://custom_user:custom_pass@localhost/custom_db"
48+
settings = Settings(db_url=custom_url)
49+
expected = "postgresql://custom_user:custom_pass@localhost/custom_db"
50+
self.assertEqual(settings.psycopg_url, expected)
51+
52+
def test_invalid_db_url(self):
53+
with self.assertRaises(ValidationError):
54+
Settings(db_url="invalid_url")
1755

1856

1957
if __name__ == "__main__":

manifests/base/params.env

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
LLM__URL=http://qwen2dot5-72b-instruct.skynet.svc.cluster.local/v1
2-
LLM__MODEL=qwen2.5-72b-instruct
1+
LLM={"base_url": "http://qwen2dot5-72b-instruct.skynet.svc.cluster.local/v1", "model_name": "qwen2.5-72b-instruct", "api_key": "NOTHING", "temperature": "0.7", "top_p": "0.8", "max_tokens": "1024", "streaming": "True", "stream_usage": "True", "extra_body": {"repetition_penalty": "1.05"}}

0 commit comments

Comments
 (0)