Skip to content

Commit 460bb46

Browse files
committed
refactor: separate read / write sessions.
1 parent 93ecddb commit 460bb46

File tree

8 files changed

+109
-27
lines changed

8 files changed

+109
-27
lines changed

api/chatbot/config.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,22 @@ class Settings(BaseSettings):
2929
model_config = SettingsConfigDict(env_nested_delimiter="__")
3030

3131
llm: dict[str, Any] = Field(default_factory=lambda: {"api_key": "NOT_SET"})
32-
db_url: PostgresDsn = "postgresql+psycopg://postgres:postgres@localhost:5432/"
33-
"""Database url. Must be a valid postgresql connection string."""
32+
postgres_primary_url: PostgresDsn = (
33+
"postgresql+psycopg://postgres:postgres@localhost:5432/"
34+
)
35+
"""Primary database url. Read/Write. Must be a valid postgresql connection string."""
36+
postgres_standby_url: PostgresDsn = postgres_primary_url
37+
"""Standby database url. Read Only. If present must be a valid postgresql connection string.
38+
Defaults to `postgres_primary_url`.
39+
"""
40+
41+
@property
42+
def psycopg_primary_url(self) -> str:
43+
return remove_postgresql_variants(str(self.postgres_primary_url))
3444

3545
@property
36-
def psycopg_url(self) -> str:
37-
return remove_postgresql_variants(str(self.db_url))
46+
def psycopg_standby_url(self) -> str:
47+
return remove_postgresql_variants(str(self.postgres_standby_url))
3848

3949

4050
settings = Settings()

api/chatbot/dependencies.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from chatbot.agent import create_agent
1111
from chatbot.config import settings
12-
from chatbot.state import chat_model, sqlalchemy_session
12+
from chatbot.state import chat_model, sqlalchemy_session, sqlalchemy_ro_session
1313

1414

1515
def UserIdHeader(alias: str | None = None, **kwargs):
@@ -47,9 +47,17 @@ async def get_sqlalchemy_session() -> AsyncGenerator[AsyncSession, None]:
4747
SqlalchemySessionDep = Annotated[AsyncSession, Depends(get_sqlalchemy_session)]
4848

4949

50+
async def get_sqlalchemy_ro_session() -> AsyncGenerator[AsyncSession, None]:
51+
async with sqlalchemy_ro_session() as session:
52+
yield session
53+
54+
55+
SqlalchemyROSessionDep = Annotated[AsyncSession, Depends(get_sqlalchemy_ro_session)]
56+
57+
5058
async def get_agent() -> AsyncGenerator[CompiledGraph, None]:
5159
async with AsyncPostgresSaver.from_conn_string(
52-
settings.psycopg_url
60+
settings.psycopg_primary_url
5361
) as checkpointer:
5462
yield create_agent(chat_model, checkpointer)
5563

api/chatbot/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def lifespan(app: FastAPI):
3131

3232
# Create checkpointer tables
3333
async with AsyncPostgresSaver.from_conn_string(
34-
settings.psycopg_url
34+
settings.psycopg_primary_url
3535
) as checkpointer:
3636
await checkpointer.setup()
3737

api/chatbot/routers/conversation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
from sqlalchemy import select
44

55
from chatbot.chains.summarization import create_smry_chain
6-
from chatbot.dependencies import AgentStateDep, SqlalchemySessionDep, UserIdHeaderDep
6+
from chatbot.dependencies import (
7+
AgentStateDep,
8+
SqlalchemyROSessionDep,
9+
SqlalchemySessionDep,
10+
UserIdHeaderDep,
11+
)
712
from chatbot.models import Conversation as ORMConversation
813
from chatbot.schemas import (
914
ChatMessage,
@@ -23,7 +28,7 @@
2328
@router.get("")
2429
async def get_conversations(
2530
userid: UserIdHeaderDep,
26-
session: SqlalchemySessionDep,
31+
session: SqlalchemyROSessionDep,
2732
) -> list[Conversation]:
2833
# TODO: support pagination
2934
stmt = (
@@ -38,7 +43,7 @@ async def get_conversations(
3843
async def get_conversation(
3944
conversation_id: str,
4045
userid: UserIdHeaderDep,
41-
session: SqlalchemySessionDep,
46+
session: SqlalchemyROSessionDep,
4247
agent_state: AgentStateDep,
4348
) -> ConversationDetail:
4449
conv: ORMConversation = await session.get(ORMConversation, conversation_id)
@@ -110,7 +115,7 @@ async def delete_conversation(
110115
async def summarize(
111116
conversation_id: str,
112117
userid: UserIdHeaderDep,
113-
session: SqlalchemySessionDep,
118+
session: SqlalchemyROSessionDep,
114119
agent_state: AgentStateDep,
115120
) -> dict[str, str]:
116121
conv: ORMConversation = await session.get(ORMConversation, conversation_id)

api/chatbot/routers/message.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from fastapi import APIRouter, HTTPException
22
from langchain_core.messages import BaseMessage
33

4-
from chatbot.dependencies import AgentDep, SqlalchemySessionDep, UserIdHeaderDep
4+
from chatbot.dependencies import AgentDep, SqlalchemyROSessionDep, UserIdHeaderDep
55
from chatbot.models import Conversation as ORMConversation
66

77
router = APIRouter(
@@ -16,7 +16,7 @@ async def thumbup(
1616
conversation_id: str,
1717
message_id: str,
1818
userid: UserIdHeaderDep,
19-
session: SqlalchemySessionDep,
19+
session: SqlalchemyROSessionDep,
2020
agent: AgentDep,
2121
) -> None:
2222
"""Using message index as the uuid is in the message body which is json dumped into redis,
@@ -48,7 +48,7 @@ async def thumbdown(
4848
conversation_id: str,
4949
message_id: str,
5050
userid: UserIdHeaderDep,
51-
session: SqlalchemySessionDep,
51+
session: SqlalchemyROSessionDep,
5252
agent: AgentDep,
5353
) -> None:
5454
"""Using message index as the uuid is in the message body which is json dumped into redis,

api/chatbot/routers/share.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from sqlalchemy import select
88
from starlette.requests import Request
99

10-
from chatbot.dependencies import AgentDep, SqlalchemySessionDep, UserIdHeaderDep
10+
from chatbot.dependencies import (
11+
AgentDep,
12+
SqlalchemyROSessionDep,
13+
SqlalchemySessionDep,
14+
UserIdHeaderDep,
15+
)
1116
from chatbot.models import Conversation as ORMConv, Share as ORMShare
1217
from chatbot.schemas import ChatMessage, CreateShare, Share
1318

@@ -22,7 +27,7 @@
2227
@router.get("")
2328
async def get_shares(
2429
userid: UserIdHeaderDep,
25-
session: SqlalchemySessionDep,
30+
session: SqlalchemyROSessionDep,
2631
) -> list[Share]:
2732
"""Get shares by userid"""
2833
# TODO: support pagination
@@ -37,7 +42,7 @@ async def get_shares(
3742
@router.get("/{share_id}")
3843
async def get_share(
3944
share_id: str,
40-
session: SqlalchemySessionDep,
45+
session: SqlalchemyROSessionDep,
4146
agent: AgentDep,
4247
) -> Share:
4348
"""Get a share by id"""

api/chatbot/state.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
sqlalchemy_engine = create_async_engine(
12-
str(settings.db_url),
12+
str(settings.postgres_primary_url),
1313
poolclass=NullPool,
1414
)
1515
sqlalchemy_session = sessionmaker(
@@ -19,4 +19,15 @@
1919
autoflush=False,
2020
class_=AsyncSession,
2121
)
22+
sqlalchemy_ro_engine = create_async_engine(
23+
str(settings.postgres_standby_url),
24+
poolclass=NullPool,
25+
)
26+
sqlalchemy_ro_session = sessionmaker(
27+
sqlalchemy_ro_engine,
28+
autocommit=False,
29+
expire_on_commit=False,
30+
autoflush=False,
31+
class_=AsyncSession,
32+
)
2233
chat_model = ChatOpenAI(**settings.llm)

api/tests/test_config.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,63 @@ def test_llm_custom(self):
3737
settings = Settings(llm=custom_llm)
3838
self.assertEqual(settings.llm, custom_llm)
3939

40-
def test_psycopg_url_default(self):
40+
def test_postgres_primary_url_default(self):
4141
settings = Settings()
42-
expected = "postgresql://postgres:postgres@localhost:5432/"
43-
self.assertEqual(settings.psycopg_url, expected)
42+
expected = "postgresql+psycopg://postgres:postgres@localhost:5432/"
43+
self.assertEqual(str(settings.postgres_primary_url), expected)
44+
45+
def test_postgres_primary_url_custom(self):
46+
custom_primary_url = (
47+
"postgresql+psycopg://primary_user:primary_pass@localhost/primary_db"
48+
)
49+
settings = Settings(postgres_primary_url=custom_primary_url)
50+
expected = "postgresql+psycopg://primary_user:primary_pass@localhost/primary_db"
51+
self.assertEqual(str(settings.postgres_primary_url), expected)
52+
53+
def test_postgres_standby_url_default(self):
54+
settings = Settings()
55+
expected = "postgresql+psycopg://postgres:postgres@localhost:5432/"
56+
self.assertEqual(str(settings.postgres_standby_url), expected)
4457

45-
def test_psycopg_url_custom(self):
46-
custom_url = "postgresql+psycopg2://custom_user:custom_pass@localhost/custom_db"
47-
settings = Settings(db_url=custom_url)
48-
expected = "postgresql://custom_user:custom_pass@localhost/custom_db"
49-
self.assertEqual(settings.psycopg_url, expected)
58+
def test_postgres_standby_url_custom(self):
59+
custom_standby_url = (
60+
"postgresql+psycopg://standby_user:standby_pass@localhost/standby_db"
61+
)
62+
settings = Settings(postgres_standby_url=custom_standby_url)
63+
expected = "postgresql+psycopg://standby_user:standby_pass@localhost/standby_db"
64+
self.assertEqual(str(settings.postgres_standby_url), expected)
5065

5166
def test_invalid_db_url(self):
5267
with self.assertRaises(ValidationError):
53-
Settings(db_url="invalid_url")
68+
Settings(postgres_primary_url="invalid_url")
69+
with self.assertRaises(ValidationError):
70+
Settings(postgres_standby_url="invalid_url")
71+
72+
def test_psycopg_primary_url_default(self):
73+
settings = Settings()
74+
expected = "postgresql://postgres:postgres@localhost:5432/"
75+
self.assertEqual(settings.psycopg_primary_url, expected)
76+
77+
def test_psycopg_primary_url_custom(self):
78+
custom_primary_url = (
79+
"postgresql+psycopg://primary_user:primary_pass@localhost/primary_db"
80+
)
81+
settings = Settings(postgres_primary_url=custom_primary_url)
82+
expected = "postgresql://primary_user:primary_pass@localhost/primary_db"
83+
self.assertEqual(settings.psycopg_primary_url, expected)
84+
85+
def test_psycopg_standby_url_default(self):
86+
settings = Settings()
87+
expected = "postgresql://postgres:postgres@localhost:5432/"
88+
self.assertEqual(settings.psycopg_standby_url, expected)
89+
90+
def test_psycopg_standby_url_custom(self):
91+
custom_standby_url = (
92+
"postgresql+psycopg://standby_user:standby_pass@localhost/standby_db"
93+
)
94+
settings = Settings(postgres_standby_url=custom_standby_url)
95+
expected = "postgresql://standby_user:standby_pass@localhost/standby_db"
96+
self.assertEqual(settings.psycopg_standby_url, expected)
5497

5598

5699
if __name__ == "__main__":

0 commit comments

Comments
 (0)