1
1
import json
2
2
import logging
3
3
from abc import ABC , abstractmethod
4
- from typing import Any , List , Optional
4
+ from typing import Any , List , Optional , Type
5
5
6
- from sqlalchemy import Column , Integer , Text , create_engine
6
+ from sqlalchemy import Column , Integer , Select , Text , create_engine , select
7
7
8
8
try :
9
9
from sqlalchemy .orm import declarative_base
22
22
class BaseMessageConverter (ABC ):
23
23
"""The class responsible for converting BaseMessage to your SQLAlchemy model."""
24
24
25
+ @abstractmethod
26
+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
27
+ raise NotImplementedError
28
+
25
29
@abstractmethod
26
30
def from_sql_model (self , sql_message : Any ) -> BaseMessage :
27
31
"""Convert a SQLAlchemy model to a BaseMessage instance."""
@@ -51,7 +55,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore
51
55
52
56
"""
53
57
54
- # Model decleared inside a function to have a dynamic table name
58
+ # Model declared inside a function to have a dynamic table name
55
59
class Message (DynamicBase ):
56
60
__tablename__ = table_name
57
61
id = Column (Integer , primary_key = True )
@@ -82,6 +86,8 @@ def get_sql_model_class(self) -> Any:
82
86
class SQLChatMessageHistory (BaseChatMessageHistory ):
83
87
"""Chat message history stored in an SQL database."""
84
88
89
+ DEFAULT_MESSAGE_CONVERTER : Type [BaseMessageConverter ] = DefaultMessageConverter
90
+
85
91
def __init__ (
86
92
self ,
87
93
session_id : str ,
@@ -93,7 +99,9 @@ def __init__(
93
99
self .connection_string = connection_string
94
100
self .engine = create_engine (connection_string , echo = False )
95
101
self .session_id_field_name = session_id_field_name
96
- self .converter = custom_message_converter or DefaultMessageConverter (table_name )
102
+ self .converter = custom_message_converter or self .DEFAULT_MESSAGE_CONVERTER (
103
+ table_name
104
+ )
97
105
self .sql_model_class = self .converter .get_sql_model_class ()
98
106
if not hasattr (self .sql_model_class , session_id_field_name ):
99
107
raise ValueError ("SQL model class must have session_id column" )
@@ -105,21 +113,25 @@ def __init__(
105
113
def _create_table_if_not_exists (self ) -> None :
106
114
self .sql_model_class .metadata .create_all (self .engine )
107
115
116
+ def _messages_query (self ) -> Select :
117
+ """Construct an SQLAlchemy selectable to query for messages"""
118
+ return (
119
+ select (self .sql_model_class )
120
+ .where (
121
+ getattr (self .sql_model_class , self .session_id_field_name )
122
+ == self .session_id
123
+ )
124
+ .order_by (self .sql_model_class .id .asc ())
125
+ )
126
+
108
127
@property
109
128
def messages (self ) -> List [BaseMessage ]: # type: ignore
110
129
"""Retrieve all messages from db"""
111
130
with self .Session () as session :
112
- result = (
113
- session .query (self .sql_model_class )
114
- .where (
115
- getattr (self .sql_model_class , self .session_id_field_name )
116
- == self .session_id
117
- )
118
- .order_by (self .sql_model_class .id .asc ())
119
- )
131
+ result = session .execute (self ._messages_query ())
120
132
messages = []
121
133
for record in result :
122
- messages .append (self .converter .from_sql_model (record ))
134
+ messages .append (self .converter .from_sql_model (record [ 0 ] ))
123
135
return messages
124
136
125
137
def add_message (self , message : BaseMessage ) -> None :
0 commit comments