Skip to content

Commit 93ecddb

Browse files
committed
refactor!: Replace Pydantic model for LLM config with plain dict
This change replaces the Pydantic model used for the LLM configuration (`llm`) with a plain dictionary, improving flexibility by exposing all parameters directly to the deployer. While it would be preferable to annotate `llm` as `langchain_openai.ChatOpenAI`, there is a limitation: Pydantic raises an error ("cannot pickle '_thread.RLock' object"), likely due to an incompatibility within `langchain_openai.ChatOpenAI`. Using a plain dictionary removes validation at startup, but the default dict (`{"api_key": "NOT_SET"}`) is sufficient to create a `ChatOpenAI` instance, making this an acceptable tradeoff for now.
1 parent 3610817 commit 93ecddb

File tree

4 files changed

+50
-26
lines changed

4 files changed

+50
-26
lines changed

api/chatbot/config.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
import re
2+
from typing import Any
23

3-
from pydantic import BaseModel, HttpUrl, PostgresDsn
4+
from pydantic import Field, PostgresDsn
45
from pydantic_settings import BaseSettings, SettingsConfigDict
56

67

7-
class LLMServiceSettings(BaseModel):
8-
url: HttpUrl = "http://localhost:8080"
9-
"""llm service url"""
10-
model: str = "cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser"
11-
creds: str = "EMPTY"
12-
13-
148
def remove_postgresql_variants(dsn: str) -> str:
159
"""Remove the 'driver' part from a connection string, if one is present in the URI scheme.
1610
@@ -34,7 +28,7 @@ def remove_postgresql_variants(dsn: str) -> str:
3428
class Settings(BaseSettings):
3529
model_config = SettingsConfigDict(env_nested_delimiter="__")
3630

37-
llm: LLMServiceSettings = LLMServiceSettings()
31+
llm: dict[str, Any] = Field(default_factory=lambda: {"api_key": "NOT_SET"})
3832
db_url: PostgresDsn = "postgresql+psycopg://postgres:postgres@localhost:5432/"
3933
"""Database url. Must be a valid postgresql connection string."""
4034

api/chatbot/state.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,4 @@
1919
autoflush=False,
2020
class_=AsyncSession,
2121
)
22-
chat_model = ChatOpenAI(
23-
openai_api_base=str(settings.llm.url),
24-
model=settings.llm.model,
25-
openai_api_key=settings.llm.creds,
26-
max_tokens=1024,
27-
streaming=True,
28-
)
22+
chat_model = ChatOpenAI(**settings.llm)

api/tests/test_config.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,56 @@
1-
import os
21
import unittest
3-
from unittest.mock import patch
42

5-
from chatbot.config import Settings
3+
from pydantic import ValidationError
4+
5+
from chatbot.config import Settings, remove_postgresql_variants
6+
7+
8+
class TestRemovePostgresqlVariants(unittest.TestCase):
9+
def test_remove_psycopg(self):
10+
dsn = "postgresql+psycopg://user:pass@localhost/dbname"
11+
expected = "postgresql://user:pass@localhost/dbname"
12+
self.assertEqual(remove_postgresql_variants(dsn), expected)
13+
14+
def test_remove_psycopg2(self):
15+
dsn = "postgresql+psycopg2://user:pass@localhost/dbname"
16+
expected = "postgresql://user:pass@localhost/dbname"
17+
self.assertEqual(remove_postgresql_variants(dsn), expected)
18+
19+
def test_remove_psycopg2cffi(self):
20+
dsn = "postgresql+psycopg2cffi://user:pass@localhost/dbname"
21+
expected = "postgresql://user:pass@localhost/dbname"
22+
self.assertEqual(remove_postgresql_variants(dsn), expected)
23+
24+
def test_no_change(self):
25+
dsn = "postgresql://user:pass@localhost/dbname"
26+
expected = "postgresql://user:pass@localhost/dbname"
27+
self.assertEqual(remove_postgresql_variants(dsn), expected)
628

729

830
class TestSettings(unittest.TestCase):
9-
def test_default_inferece_url(self):
31+
def test_llm_default(self):
1032
settings = Settings()
11-
self.assertEqual(str(settings.llm.url), "http://localhost:8080")
33+
self.assertEqual(settings.llm, {"api_key": "NOT_SET"})
34+
35+
def test_llm_custom(self):
36+
custom_llm = {"model": "gpt-3", "version": "davinci"}
37+
settings = Settings(llm=custom_llm)
38+
self.assertEqual(settings.llm, custom_llm)
1239

13-
@patch.dict(os.environ, {"LLM__URL": "http://foo.bar.com"}, clear=True)
14-
def test_inference_server_url(self):
40+
def test_psycopg_url_default(self):
1541
settings = Settings()
16-
self.assertEqual(str(settings.llm.url), "http://foo.bar.com/")
42+
expected = "postgresql://postgres:postgres@localhost:5432/"
43+
self.assertEqual(settings.psycopg_url, expected)
44+
45+
def test_psycopg_url_custom(self):
46+
custom_url = "postgresql+psycopg2://custom_user:custom_pass@localhost/custom_db"
47+
settings = Settings(db_url=custom_url)
48+
expected = "postgresql://custom_user:custom_pass@localhost/custom_db"
49+
self.assertEqual(settings.psycopg_url, expected)
50+
51+
def test_invalid_db_url(self):
52+
with self.assertRaises(ValidationError):
53+
Settings(db_url="invalid_url")
1754

1855

1956
if __name__ == "__main__":

manifests/base/params.env

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
LLM__URL=http://qwen2dot5-72b-instruct.skynet.svc.cluster.local/v1
2-
LLM__MODEL=qwen2.5-72b-instruct
1+
LLM={"base_url": "http://qwen2dot5-72b-instruct.skynet.svc.cluster.local/v1", "model_name": "qwen2.5-72b-instruct", "api_key": "NOTHING", "temperature": "0.7", "top_p": "0.8", "max_tokens": "1024", "streaming": "True", "stream_usage": "True", "extra_body": {"repetition_penalty": "1.05"}}

0 commit comments

Comments
 (0)