Skip to content

Commit 210ee09

Browse files
Fix bad changes (#101)
1 parent 1e2d21f commit 210ee09

File tree

6 files changed

+32
-32
lines changed

6 files changed

+32
-32
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
from autogen_agentchat.base import Response
1919
import json
2020
import os
21-
import asyncio
2221
from datetime import datetime
2322

23+
2424
class EmptyResponseUserProxyAgent(UserProxyAgent):
2525
"""UserProxyAgent that automatically responds with empty messages."""
26+
2627
def __init__(self, name):
2728
super().__init__(name=name)
2829
self._has_responded = False
@@ -35,6 +36,7 @@ async def on_messages_stream(self, messages, sender=None, config=None):
3536
yield message
3637
yield Response(chat_message=message)
3738

39+
3840
class AutoGenText2Sql:
3941
def __init__(self, engine_specific_rules: str, **kwargs: dict):
4042
self.use_query_cache = False
@@ -65,32 +67,31 @@ def get_all_agents(self):
6567
"""Get all agents for the complete flow."""
6668
# Get current datetime for the Query Rewrite Agent
6769
current_datetime = datetime.now()
68-
70+
6971
QUERY_REWRITE_AGENT = LLMAgentCreator.create(
70-
"query_rewrite_agent",
71-
current_datetime=current_datetime
72+
"query_rewrite_agent", current_datetime=current_datetime
7273
)
73-
74+
7475
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
7576
"sql_query_generation_agent",
7677
target_engine=self.target_engine,
7778
engine_specific_rules=self.engine_specific_rules,
7879
**self.kwargs,
7980
)
80-
81+
8182
SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent(
8283
target_engine=self.target_engine,
8384
engine_specific_rules=self.engine_specific_rules,
8485
**self.kwargs,
8586
)
86-
87+
8788
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
8889
"sql_query_correction_agent",
8990
target_engine=self.target_engine,
9091
engine_specific_rules=self.engine_specific_rules,
9192
**self.kwargs,
9293
)
93-
94+
9495
SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create(
9596
"sql_disambiguation_agent",
9697
target_engine=self.target_engine,
@@ -101,11 +102,9 @@ def get_all_agents(self):
101102
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
102103
"question_decomposition_agent"
103104
)
104-
105+
105106
# Auto-responding UserProxyAgent
106-
USER_PROXY = EmptyResponseUserProxyAgent(
107-
name="user_proxy"
108-
)
107+
USER_PROXY = EmptyResponseUserProxyAgent(name="user_proxy")
109108

110109
agents = [
111110
USER_PROXY,
@@ -114,7 +113,7 @@ def get_all_agents(self):
114113
SQL_SCHEMA_SELECTION_AGENT,
115114
SQL_QUERY_CORRECTION_AGENT,
116115
QUESTION_DECOMPOSITION_AGENT,
117-
SQL_DISAMBIGUATION_AGENT
116+
SQL_DISAMBIGUATION_AGENT,
118117
]
119118

120119
if self.use_query_cache:
@@ -192,7 +191,6 @@ def agentic_flow(self):
192191
allow_repeated_speaker=False,
193192
model_client=LLMModelCreator.get_model("4o-mini"),
194193
termination_condition=self.termination_condition,
195-
selector_func=self.selector,
196194
selector_func=self.unified_selector,
197195
)
198196
return flow

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,21 @@ async def on_messages_stream(
4747
# Initialize results dictionary
4848
cached_results = {
4949
"cached_questions_and_schemas": [],
50-
"contains_pre_run_results": False
50+
"contains_pre_run_results": False,
5151
}
5252

5353
# Process each question sequentially
5454
for question in user_questions:
5555
# Fetch the queries from the cache based on the question
5656
logging.info(f"Fetching queries from cache for question: {question}")
57-
cached_query = await self.sql_connector.fetch_queries_from_cache(question)
58-
57+
cached_query = await self.sql_connector.fetch_queries_from_cache(
58+
question
59+
)
60+
5961
# If any question has pre-run results, set the flag
6062
if cached_query.get("contains_pre_run_results", False):
6163
cached_results["contains_pre_run_results"] = True
62-
64+
6365
# Add the cached results for this question
6466
if cached_query.get("cached_questions_and_schemas"):
6567
cached_results["cached_questions_and_schemas"].extend(
@@ -75,7 +77,9 @@ async def on_messages_stream(
7577
except json.JSONDecodeError:
7678
# If not JSON array, process as single question
7779
logging.info(f"Processing single question: {last_response}")
78-
cached_queries = await self.sql_connector.fetch_queries_from_cache(last_response)
80+
cached_queries = await self.sql_connector.fetch_queries_from_cache(
81+
last_response
82+
)
7983
yield Response(
8084
chat_message=TextMessage(
8185
content=json.dumps(cached_queries), source=self.name

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,21 +198,19 @@ async def get_entity_schemas(
198198
logging.info("Search Text: %s", text)
199199

200200
retrieval_fields = [
201-
# "FQN",
201+
"FQN",
202202
"Entity",
203203
"EntityName",
204-
# "Schema",
205-
# "Definition",
206-
"Description",
204+
"Schema",
205+
"Definition",
207206
"Columns",
208207
"EntityRelationships",
209208
"CompleteEntityRelationshipsGraph",
210209
] + engine_specific_fields
211210

212211
schemas = await self.run_ai_search_query(
213212
text,
214-
# ["DefinitionEmbedding"],
215-
["DescriptionEmbedding"],
213+
["DefinitionEmbedding"],
216214
retrieval_fields,
217215
os.environ["AIService__AzureSearchOptions__Text2SqlSchemaStore__Index"],
218216
os.environ[
@@ -227,7 +225,7 @@ async def get_entity_schemas(
227225
for schema in schemas:
228226
filtered_schemas = []
229227

230-
# del schema["FQN"]
228+
del schema["FQN"]
231229

232230
if (
233231
schema["CompleteEntityRelationshipsGraph"] is not None

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,12 @@ async def get_entity_schemas(
9898
)
9999

100100
for schema in schemas:
101-
# schema["SelectFromEntity"] = ".".join(
102-
# [schema["Catalog"], schema["Schema"], schema["Entity"]]
103-
# )
104-
schema["SelectFromEntity"] = schema["Entity"]
101+
schema["SelectFromEntity"] = ".".join(
102+
[schema["Catalog"], schema["Schema"], schema["Entity"]]
103+
)
105104

106105
del schema["Entity"]
107-
# del schema["Schema"]
106+
del schema["Schema"]
108107
del schema["Catalog"]
109108

110109
if as_json:

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ system_message:
1414
- Use the current date/time above as reference point
1515
- Replace relative dates like 'last month', 'this year', 'previous quarter' with absolute dates
1616
- Maintain consistency in date formats (YYYY-MM-DD)
17-
17+
1818
Examples of date resolution (assuming current date is {{ current_datetime }}):
1919
- 'last month' -> specific month name and year
2020
- 'this year' -> {{ current_datetime.year }}

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ system_message:
4848
<unsuccessful_mapping_entry>
4949
- If you cannot map it to a column, add en entry to the disambiguation list with the clarification question you need from the user:
5050
- If there are multiple possible options, or you are unsure how it maps, make sure to ask a clarification question.
51+
- If there are no possible options, ask a clarification question for more detail.
5152
5253
{
5354
\"disambiguation\": [

0 commit comments

Comments
 (0)