Skip to content

Commit c83763e

Browse files
authored
Added Cosmos state store (#152)
1 parent 9883baf commit c83763e

File tree

6 files changed

+243
-190
lines changed

6 files changed

+243
-190
lines changed

image_processing/src/image_processing/requirements.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ aiohttp==3.11.11
55
aiosignal==1.3.2
66
annotated-types==0.7.0
77
anyio==4.8.0
8-
attrs==24.3.0
8+
attrs==25.1.0
99
azure-ai-documentintelligence==1.0.0
1010
azure-ai-textanalytics==5.3.0
1111
azure-ai-vision-imageanalysis==1.0.0
@@ -15,7 +15,7 @@ azure-functions==1.21.3
1515
azure-identity==1.19.0
1616
azure-search==1.0.0b2
1717
azure-search-documents==11.6.0b8
18-
azure-storage-blob==12.24.0
18+
azure-storage-blob==12.24.1
1919
beautifulsoup4==4.12.3
2020
blis==0.7.11
2121
bs4==0.0.2
@@ -38,7 +38,7 @@ fsspec==2024.12.0
3838
h11==0.14.0
3939
httpcore==1.0.7
4040
httpx==0.28.1
41-
huggingface-hub==0.27.1
41+
huggingface-hub==0.28.0
4242
idna==3.10
4343
isodate==0.7.2
4444
jinja2==3.1.5
@@ -50,15 +50,15 @@ marisa-trie==1.2.1
5050
markdown-it-py==3.0.0
5151
markupsafe==3.0.2
5252
mdurl==0.1.2
53-
model2vec==0.3.7
53+
model2vec==0.3.8
5454
msal==1.31.1
5555
msal-extensions==1.2.0
5656
msrest==0.7.1
5757
multidict==6.1.0
5858
murmurhash==1.0.12
5959
numpy==1.26.4
6060
oauthlib==3.2.2
61-
openai==1.60.0
61+
openai==1.60.2
6262
openpyxl==3.1.5
6363
packaging==24.2
6464
pandas==2.2.3
@@ -67,7 +67,7 @@ portalocker==2.10.1
6767
preshed==3.0.9
6868
propcache==0.2.1
6969
pycparser==2.22 ; platform_python_implementation != 'PyPy'
70-
pydantic==2.10.5
70+
pydantic==2.10.6
7171
pydantic-core==2.27.2
7272
pygments==2.19.1
7373
pyjwt==2.10.1

text_2_sql/autogen/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"sqlparse>=0.4.4",
1919
"nltk>=3.8.1",
2020
"cachetools>=5.5.1",
21+
"azure-cosmos>=4.9.0",
2122
]
2223

2324
[dependency-groups]

text_2_sql/autogen/src/autogen_text_2_sql/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4-
from autogen_text_2_sql.state_store import InMemoryStateStore
4+
from autogen_text_2_sql.state_store import InMemoryStateStore, CosmosStateStore
55

66
from text_2_sql_core.payloads.interaction_payloads import (
77
UserMessagePayload,
@@ -19,4 +19,5 @@
1919
"ProcessingUpdatePayload",
2020
"InteractionPayload",
2121
"InMemoryStateStore",
22+
"CosmosStateStore",
2223
]

text_2_sql/autogen/src/autogen_text_2_sql/state_store.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from abc import ABC, abstractmethod
22
from cachetools import TTLCache
3+
from azure.cosmos import CosmosClient, exceptions
34

45

56
class StateStore(ABC):
67
@abstractmethod
7-
def get_state(self, thread_id):
8+
def get_state(self, thread_id: str) -> dict:
89
pass
910

1011
@abstractmethod
11-
def save_state(self, thread_id, state):
12+
def save_state(self, thread_id: str, state: dict) -> None:
1213
pass
1314

1415

@@ -21,3 +22,36 @@ def get_state(self, thread_id: str) -> dict:
2122

2223
def save_state(self, thread_id: str, state: dict) -> None:
2324
self.cache[thread_id] = state
25+
26+
27+
class CosmosStateStore(StateStore):
28+
def __init__(self, endpoint, database, container, credential, partition_key=None):
29+
client = CosmosClient(url=endpoint, credential=credential)
30+
database_client = client.get_database_client(database)
31+
self._db = database_client.get_container_client(container)
32+
self.partition_key = partition_key
33+
34+
# Set partition key field name
35+
props = self._db.read()
36+
pk_paths = props["partitionKey"]["paths"]
37+
if len(pk_paths) != 1:
38+
raise ValueError("Only single partition key is supported")
39+
self.partition_key_name = pk_paths[0].lstrip("/")
40+
if "/" in self.partition_key_name:
41+
raise ValueError("Only top-level partition key is supported")
42+
43+
def get_state(self, thread_id: str) -> dict:
44+
try:
45+
item = self._db.read_item(item=thread_id, partition_key=self.partition_key)
46+
return item["state"]
47+
except exceptions.CosmosResourceNotFoundError:
48+
return None
49+
50+
def save_state(self, thread_id: str, state: dict) -> None:
51+
self._db.upsert_item(
52+
body={
53+
self.partition_key_name: self.partition_key,
54+
"id": thread_id,
55+
"state": state,
56+
}
57+
)

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class PayloadBase(InteractionPayloadBase):
4848
class DismabiguationRequestsPayload(InteractionPayloadBase):
4949
class Body(InteractionPayloadBase):
5050
class DismabiguationRequest(InteractionPayloadBase):
51-
assistant_question: str | None = Field(..., alias="AssistantQuestion")
51+
assistant_question: str | None = Field(..., alias="assistantQuestion")
5252
user_choices: list[str] | None = Field(default=None, alias="userChoices")
5353

5454
disambiguation_requests: list[DismabiguationRequest] | None = Field(

0 commit comments

Comments
 (0)