
Improvements to method of query rewriting #156


Merged

3 commits merged on Feb 1, 2025
Changes from all commits
@@ -96,9 +96,9 @@ async def on_messages_stream(
injected_parameters = {}

# Load the json of the last message to populate the final output object
message_rewrites = json.loads(last_response)
sequential_rounds = json.loads(last_response)

logging.info(f"Query Rewrites: {message_rewrites}")
logging.info(f"Query Rewrites: {sequential_rounds}")

async def consume_inner_messages_from_agentic_flow(
agentic_flow, identifier, filtered_parallel_messages
@@ -197,7 +197,7 @@ async def consume_inner_messages_from_agentic_flow(

# Convert all_non_database_query to lowercase string and compare
all_non_database_query = str(
message_rewrites.get("all_non_database_query", "false")
sequential_rounds.get("all_non_database_query", "false")
).lower()

if all_non_database_query == "true":
@@ -210,84 +210,97 @@ async def consume_inner_messages_from_agentic_flow(
return

# Start processing sub-queries
for message_rewrite in message_rewrites["decomposed_user_messages"]:
logging.info(f"Processing sub-query: {message_rewrite}")
# Create an instance of the InnerAutoGenText2Sql class
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)

identifier = ", ".join(message_rewrite)

# Add database connection info to injected parameters
query_params = injected_parameters.copy() if injected_parameters else {}
if "Text2Sql__Tsql__ConnectionString" in os.environ:
query_params["database_connection_string"] = os.environ[
"Text2Sql__Tsql__ConnectionString"
]
if "Text2Sql__Tsql__Database" in os.environ:
query_params["database_name"] = os.environ["Text2Sql__Tsql__Database"]

# Launch tasks for each sub-query
inner_solving_generators.append(
consume_inner_messages_from_agentic_flow(
inner_autogen_text_2_sql.process_user_message(
user_message=message_rewrite,
injected_parameters=query_params,
),
identifier,
filtered_parallel_messages,
)
)
for sequential_round in sequential_rounds["decomposed_user_messages"]:
logging.info(f"Processing round: {sequential_round}")

for parallel_message in sequential_round:
logging.info(f"Parallel Message: {parallel_message}")

# Create an instance of the InnerAutoGenText2Sql class
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)

# Add database connection info to injected parameters
query_params = injected_parameters.copy() if injected_parameters else {}
if "Text2Sql__Tsql__ConnectionString" in os.environ:
query_params["database_connection_string"] = os.environ[
"Text2Sql__Tsql__ConnectionString"
]
if "Text2Sql__Tsql__Database" in os.environ:
query_params["database_name"] = os.environ[
"Text2Sql__Tsql__Database"
]

# Launch tasks for each sub-query
inner_solving_generators.append(
consume_inner_messages_from_agentic_flow(
inner_autogen_text_2_sql.process_user_message(
user_message=parallel_message,
injected_parameters=query_params,
database_results=filtered_parallel_messages.database_results,
),
parallel_message,
filtered_parallel_messages,
)
)

logging.info(
"Created %i Inner Solving Generators", len(inner_solving_generators)
)
logging.info("Starting Inner Solving Generators")
combined_message_streams = stream.merge(*inner_solving_generators)

async with combined_message_streams.stream() as streamer:
async for inner_message in streamer:
if isinstance(inner_message, TextMessage):
logging.debug(f"Inner Solving Message: {inner_message}")
yield inner_message

# Log final results for debugging or auditing
logging.info(
"Database Results: %s", filtered_parallel_messages.database_results
)
logging.info(
"Disambiguation Requests: %s",
filtered_parallel_messages.disambiguation_requests,
)

logging.info(
"Created %i Inner Solving Generators", len(inner_solving_generators)
)
logging.info("Starting Inner Solving Generators")
combined_message_streams = stream.merge(*inner_solving_generators)

async with combined_message_streams.stream() as streamer:
async for inner_message in streamer:
if isinstance(inner_message, TextMessage):
logging.debug(f"Inner Solving Message: {inner_message}")
yield inner_message

# Log final results for debugging or auditing
logging.info(
"Database Results: %s", filtered_parallel_messages.database_results
)
logging.info(
"Disambiguation Requests: %s",
filtered_parallel_messages.disambiguation_requests,
)
if (
    max(map(len, filtered_parallel_messages.disambiguation_requests.values()))
    > 0
):
    # Final response
    yield Response(
        chat_message=TextMessage(
            content=json.dumps(
                {
                    "contains_disambiguation_requests": True,
                    "disambiguation_requests": filtered_parallel_messages.disambiguation_requests,
                }
            ),
            source=self.name,
        ),
    )
else:
    # Final response
    yield Response(
        chat_message=TextMessage(
            content=json.dumps(
                {
                    "contains_database_results": True,
                    "database_results": filtered_parallel_messages.database_results,
                }
            ),
            source=self.name,
        ),
    )

# Check for disambiguation requests before processing the next round
if (
    max(
        map(
            len, filtered_parallel_messages.disambiguation_requests.values()
        )
    )
    > 0
):
    # Final response
    yield Response(
        chat_message=TextMessage(
            content=json.dumps(
                {
                    "contains_disambiguation_requests": True,
                    "disambiguation_requests": filtered_parallel_messages.disambiguation_requests,
                }
            ),
            source=self.name,
        ),
    )

    break

# Final response
yield Response(
    chat_message=TextMessage(
        content=json.dumps(
            {
                "contains_database_results": True,
                "database_results": filtered_parallel_messages.database_results,
            }
        ),
        source=self.name,
    ),
)

async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass
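
A note on the rewriter output driving the nested loop above: its exact schema is not part of this diff, so the sketch below is an assumption reconstructed from the keys the code reads (decomposed_user_messages as a list of rounds, each round a list of messages, and all_non_database_query as a "true"/"false" string). The example questions are invented.

# Illustrative shape of the rewriter output consumed by the new loop (assumed, not taken from this PR)
sequential_rounds = {
    "all_non_database_query": "false",
    "decomposed_user_messages": [
        # Round 1: independent sub-questions, dispatched in parallel
        ["Total sales for 2023", "Total sales for 2024"],
        # Round 2: starts only after round 1 completes, so it can reuse round 1's database results
        ["Compare the yearly totals"],
    ],
}

for sequential_round in sequential_rounds["decomposed_user_messages"]:
    # Rounds run sequentially; messages within a round fan out to parallel inner flows
    for parallel_message in sequential_round:
        print(parallel_message)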
@@ -44,15 +44,15 @@ async def on_messages_stream(
try:
request_details = json.loads(messages[0].content)
injected_parameters = request_details["injected_parameters"]
user_messages = request_details["user_message"]
logging.info(f"Processing messages: {user_messages}")
user_message = request_details["user_message"]
logging.info(f"Processing messages: {user_message}")
logging.info(f"Input Parameters: {injected_parameters}")
except json.JSONDecodeError:
# If not JSON array, process as single message
raise ValueError("Could not load message")

cached_results = await self.agent.process_message(
user_messages, injected_parameters
user_message, injected_parameters
)
yield Response(
chat_message=TextMessage(
@@ -43,19 +43,14 @@ async def on_messages_stream(
# Try to parse as JSON first
try:
request_details = json.loads(messages[0].content)
messages = request_details["question"]
message = request_details["user_message"]
except (json.JSONDecodeError, KeyError):
# If not JSON or missing question key, use content directly
messages = messages[0].content
message = messages[0].content

if isinstance(messages, str):
messages = [messages]
elif not isinstance(messages, list):
messages = [str(messages)]
logging.info("Processing message: %s", message)

logging.info(f"Processing questions: {messages}")

final_results = await self.agent.process_message(messages)
final_results = await self.agent.process_message(message)

yield Response(
chat_message=TextMessage(
@@ -177,6 +177,7 @@ def process_user_message(
self,
user_message: str,
injected_parameters: dict = None,
database_results: dict = None,
):
"""Process the complete question through the unified system.

@@ -200,6 +201,9 @@ def process_user_message(
"injected_parameters": injected_parameters,
}

if database_results:
agent_input["database_results"] = database_results

return self.agentic_flow.run_stream(task=json.dumps(agent_input))
finally:
# Restore original environment
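
For reference, here is a self-contained sketch of the task payload handed to run_stream after this change. The values are invented, and the user_message key is inferred from the downstream agents in this PR that read request_details["user_message"]; the optional database_results entry is what this hunk adds.

import json

# Illustrative payload only; real content comes from the orchestrating agent
agent_input = {
    "user_message": "Compare the yearly totals",
    "injected_parameters": {"database_name": "sales_db"},
}

# Results gathered in an earlier sequential round (shape assumed, not shown in this diff)
database_results = {"Total sales for 2023": [{"total_sales": 1000}]}
if database_results:
    agent_input["database_results"] = database_results

print(json.dumps(agent_input, indent=2))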
@@ -8,39 +8,35 @@ class SqlQueryCacheAgentCustomAgent:
def __init__(self):
self.sql_connector = ConnectorFactory.get_database_connector()

async def process_message(
self, messages: list[str], injected_parameters: dict
) -> dict:
async def process_message(self, message: str, injected_parameters: dict) -> dict:
# Initialize results dictionary
cached_results = {
"cached_sql_queries_with_schemas_from_cache": [],
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
}

# Process each question sequentially
for message in messages:
# Fetch the queries from the cache based on the question
logging.info(f"Fetching queries from cache for question: {message}")
cached_query = (
await self.sql_connector.fetch_sql_queries_with_schemas_from_cache(
message, injected_parameters=injected_parameters
)
# Fetch the queries from the cache based on the question
logging.info(f"Fetching queries from cache for question: {message}")
cached_query = (
await self.sql_connector.fetch_sql_queries_with_schemas_from_cache(
message, injected_parameters=injected_parameters
)
)

# If any question has pre-run results, set the flag
if cached_query.get(
"contains_cached_sql_queries_with_schemas_from_cache_database_results",
False,
):
cached_results[
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
] = True
# If any question has pre-run results, set the flag
if cached_query.get(
"contains_cached_sql_queries_with_schemas_from_cache_database_results",
False,
):
cached_results[
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
] = True

# Add the cached results for this question
if cached_query.get("cached_sql_queries_with_schemas_from_cache"):
cached_results["cached_sql_queries_with_schemas_from_cache"].extend(
cached_query["cached_sql_queries_with_schemas_from_cache"]
)
# Add the cached results for this question
if cached_query.get("cached_sql_queries_with_schemas_from_cache"):
cached_results["cached_sql_queries_with_schemas_from_cache"].extend(
cached_query["cached_sql_queries_with_schemas_from_cache"]
)

logging.info(f"Final cached results: {cached_results}")
return cached_results
@@ -22,47 +22,35 @@ def __init__(self, **kwargs):

self.system_prompt = Template(system_prompt).render(kwargs)

async def process_message(self, messages: list[str]) -> dict:
logging.info(f"user inputs: {messages}")

entity_tasks = []

for message in messages:
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": message},
]
entity_tasks.append(
self.open_ai_connector.run_completion_request(
messages, response_format=SQLSchemaSelectionAgentOutput
)
)
async def process_message(self, message: str) -> dict:
logging.info(f"Processing message: {message}")

entity_results = await asyncio.gather(*entity_tasks)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": message},
]
entity_result = await self.open_ai_connector.run_completion_request(
messages, response_format=SQLSchemaSelectionAgentOutput
)

entity_search_tasks = []
column_search_tasks = []

for entity_result in entity_results:
logging.info(f"Entity result: {entity_result}")
logging.info(f"Entity result: {entity_result}")

for entity_group in entity_result.entities:
logging.info("Searching for schemas for entity group: %s", entity_group)
entity_search_tasks.append(
self.sql_connector.get_entity_schemas(
" ".join(entity_group), as_json=False
)
for entity_group in entity_result.entities:
logging.info("Searching for schemas for entity group: %s", entity_group)
entity_search_tasks.append(
self.sql_connector.get_entity_schemas(
" ".join(entity_group), as_json=False
)
)

for filter_condition in entity_result.filter_conditions:
logging.info(
"Searching for column values for filter: %s", filter_condition
)
column_search_tasks.append(
self.sql_connector.get_column_values(
filter_condition, as_json=False
)
)
for filter_condition in entity_result.filter_conditions:
logging.info("Searching for column values for filter: %s", filter_condition)
column_search_tasks.append(
self.sql_connector.get_column_values(filter_condition, as_json=False)
)

schemas_results = await asyncio.gather(*entity_search_tasks)
column_value_results = await asyncio.gather(*column_search_tasks)
@@ -274,3 +274,5 @@ system_message:
TERMINATE
</output_format>
"
tools:
- sql_get_entity_schemas_tool