Skip to content

Commit 27a2193

Browse files
committed
Fix
1 parent a70fa78 commit 27a2193

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
from autogen_agentchat.base import Response
1919
import json
2020
import os
21-
import asyncio
2221
from datetime import datetime
2322

23+
2424
class EmptyResponseUserProxyAgent(UserProxyAgent):
2525
"""UserProxyAgent that automatically responds with empty messages."""
26+
2627
def __init__(self, name):
2728
super().__init__(name=name)
2829
self._has_responded = False
@@ -35,6 +36,7 @@ async def on_messages_stream(self, messages, sender=None, config=None):
3536
yield message
3637
yield Response(chat_message=message)
3738

39+
3840
class AutoGenText2Sql:
3941
def __init__(self, engine_specific_rules: str, **kwargs: dict):
4042
self.use_query_cache = False
@@ -65,32 +67,31 @@ def get_all_agents(self):
6567
"""Get all agents for the complete flow."""
6668
# Get current datetime for the Query Rewrite Agent
6769
current_datetime = datetime.now()
68-
70+
6971
QUERY_REWRITE_AGENT = LLMAgentCreator.create(
70-
"query_rewrite_agent",
71-
current_datetime=current_datetime
72+
"query_rewrite_agent", current_datetime=current_datetime
7273
)
73-
74+
7475
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
7576
"sql_query_generation_agent",
7677
target_engine=self.target_engine,
7778
engine_specific_rules=self.engine_specific_rules,
7879
**self.kwargs,
7980
)
80-
81+
8182
SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent(
8283
target_engine=self.target_engine,
8384
engine_specific_rules=self.engine_specific_rules,
8485
**self.kwargs,
8586
)
86-
87+
8788
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
8889
"sql_query_correction_agent",
8990
target_engine=self.target_engine,
9091
engine_specific_rules=self.engine_specific_rules,
9192
**self.kwargs,
9293
)
93-
94+
9495
SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create(
9596
"sql_disambiguation_agent",
9697
target_engine=self.target_engine,
@@ -101,11 +102,9 @@ def get_all_agents(self):
101102
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
102103
"question_decomposition_agent"
103104
)
104-
105+
105106
# Auto-responding UserProxyAgent
106-
USER_PROXY = EmptyResponseUserProxyAgent(
107-
name="user_proxy"
108-
)
107+
USER_PROXY = EmptyResponseUserProxyAgent(name="user_proxy")
109108

110109
agents = [
111110
USER_PROXY,
@@ -114,7 +113,7 @@ def get_all_agents(self):
114113
SQL_SCHEMA_SELECTION_AGENT,
115114
SQL_QUERY_CORRECTION_AGENT,
116115
QUESTION_DECOMPOSITION_AGENT,
117-
SQL_DISAMBIGUATION_AGENT
116+
SQL_DISAMBIGUATION_AGENT,
118117
]
119118

120119
if self.use_query_cache:
@@ -192,7 +191,6 @@ def agentic_flow(self):
192191
allow_repeated_speaker=False,
193192
model_client=LLMModelCreator.get_model("4o-mini"),
194193
termination_condition=self.termination_condition,
195-
selector_func=self.selector,
196194
selector_func=self.unified_selector,
197195
)
198196
return flow

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,21 @@ async def on_messages_stream(
4747
# Initialize results dictionary
4848
cached_results = {
4949
"cached_questions_and_schemas": [],
50-
"contains_pre_run_results": False
50+
"contains_pre_run_results": False,
5151
}
5252

5353
# Process each question sequentially
5454
for question in user_questions:
5555
# Fetch the queries from the cache based on the question
5656
logging.info(f"Fetching queries from cache for question: {question}")
57-
cached_query = await self.sql_connector.fetch_queries_from_cache(question)
58-
57+
cached_query = await self.sql_connector.fetch_queries_from_cache(
58+
question
59+
)
60+
5961
# If any question has pre-run results, set the flag
6062
if cached_query.get("contains_pre_run_results", False):
6163
cached_results["contains_pre_run_results"] = True
62-
64+
6365
# Add the cached results for this question
6466
if cached_query.get("cached_questions_and_schemas"):
6567
cached_results["cached_questions_and_schemas"].extend(
@@ -75,7 +77,9 @@ async def on_messages_stream(
7577
except json.JSONDecodeError:
7678
# If not JSON array, process as single question
7779
logging.info(f"Processing single question: {last_response}")
78-
cached_queries = await self.sql_connector.fetch_queries_from_cache(last_response)
80+
cached_queries = await self.sql_connector.fetch_queries_from_cache(
81+
last_response
82+
)
7983
yield Response(
8084
chat_message=TextMessage(
8185
content=json.dumps(cached_queries), source=self.name

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ system_message:
1414
- Use the current date/time above as reference point
1515
- Replace relative dates like 'last month', 'this year', 'previous quarter' with absolute dates
1616
- Maintain consistency in date formats (YYYY-MM-DD)
17-
17+
1818
Examples of date resolution (assuming current date is {{ current_datetime }}):
1919
- 'last month' -> specific month name and year
2020
- 'this year' -> {{ current_datetime.year }}

0 commit comments

Comments
 (0)