Skip to content

Commit 571ce1f

Browse files
committed
Generalize SQLChatMessageHistory to make code a bit more reusable
1 parent ef21cde commit 571ce1f

File tree

1 file changed

+25
-13
lines changed
  • libs/langchain/langchain/memory/chat_message_histories

1 file changed

+25
-13
lines changed

libs/langchain/langchain/memory/chat_message_histories/sql.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import json
22
import logging
33
from abc import ABC, abstractmethod
4-
from typing import Any, List, Optional
4+
from typing import Any, List, Optional, Type
55

6-
from sqlalchemy import Column, Integer, Text, create_engine
6+
from sqlalchemy import Column, Integer, Select, Text, create_engine, select
77

88
try:
99
from sqlalchemy.orm import declarative_base
@@ -22,6 +22,10 @@
2222
class BaseMessageConverter(ABC):
2323
"""The class responsible for converting BaseMessage to your SQLAlchemy model."""
2424

25+
@abstractmethod
26+
def __init__(self, *args: Any, **kwargs: Any) -> None:
27+
raise NotImplementedError
28+
2529
@abstractmethod
2630
def from_sql_model(self, sql_message: Any) -> BaseMessage:
2731
"""Convert a SQLAlchemy model to a BaseMessage instance."""
@@ -51,7 +55,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore
5155
5256
"""
5357

54-
# Model decleared inside a function to have a dynamic table name
58+
# Model declared inside a function to have a dynamic table name
5559
class Message(DynamicBase):
5660
__tablename__ = table_name
5761
id = Column(Integer, primary_key=True)
@@ -82,6 +86,8 @@ def get_sql_model_class(self) -> Any:
8286
class SQLChatMessageHistory(BaseChatMessageHistory):
8387
"""Chat message history stored in an SQL database."""
8488

89+
DEFAULT_MESSAGE_CONVERTER: Type[BaseMessageConverter] = DefaultMessageConverter
90+
8591
def __init__(
8692
self,
8793
session_id: str,
@@ -93,7 +99,9 @@ def __init__(
9399
self.connection_string = connection_string
94100
self.engine = create_engine(connection_string, echo=False)
95101
self.session_id_field_name = session_id_field_name
96-
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
102+
self.converter = custom_message_converter or self.DEFAULT_MESSAGE_CONVERTER(
103+
table_name
104+
)
97105
self.sql_model_class = self.converter.get_sql_model_class()
98106
if not hasattr(self.sql_model_class, session_id_field_name):
99107
raise ValueError("SQL model class must have session_id column")
@@ -105,21 +113,25 @@ def __init__(
105113
def _create_table_if_not_exists(self) -> None:
106114
self.sql_model_class.metadata.create_all(self.engine)
107115

116+
def _messages_query(self) -> Select:
117+
"""Construct an SQLAlchemy selectable to query for messages"""
118+
return (
119+
select(self.sql_model_class)
120+
.where(
121+
getattr(self.sql_model_class, self.session_id_field_name)
122+
== self.session_id
123+
)
124+
.order_by(self.sql_model_class.id.asc())
125+
)
126+
108127
@property
109128
def messages(self) -> List[BaseMessage]: # type: ignore
110129
"""Retrieve all messages from db"""
111130
with self.Session() as session:
112-
result = (
113-
session.query(self.sql_model_class)
114-
.where(
115-
getattr(self.sql_model_class, self.session_id_field_name)
116-
== self.session_id
117-
)
118-
.order_by(self.sql_model_class.id.asc())
119-
)
131+
result = session.execute(self._messages_query())
120132
messages = []
121133
for record in result:
122-
messages.append(self.converter.from_sql_model(record))
134+
messages.append(self.converter.from_sql_model(record[0]))
123135
return messages
124136

125137
def add_message(self, message: BaseMessage) -> None:

0 commit comments

Comments
 (0)