diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index 4e49cdc6..d4f22d4c 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -50,7 +50,7 @@ "source": [ "import dotenv\n", "import logging\n", - "from autogen_text_2_sql import AutoGenText2Sql" + "from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody" ] }, { @@ -100,7 +100,7 @@ "metadata": {}, "outputs": [], "source": [ - "async for message in agentic_text_2_sql.process_question(question=\"What total number of orders in June 2008?\"):\n", + "async for message in agentic_text_2_sql.process_question(AgentRequestBody(question=\"What total number of orders in June 2008?\")):\n", " logging.info(\"Received %s Message from Text2SQL System\", message)" ] }, diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index defc348d..37a3b01b 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql +from text_2_sql_core.payloads.agent_response import AgentRequestBody -__all__ = ["AutoGenText2Sql"] +__all__ = ["AutoGenText2Sql", "AgentRequestBody"] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index b021b01d..4c72be37 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -18,11 +18,17 @@ import os from datetime import datetime -from text_2_sql_core.payloads import ( +from text_2_sql_core.payloads.agent_response import ( + AgentResponse, + AgentRequestBody, AnswerWithSources, - UserInformationRequest, + Source, + DismabiguationRequests, +) +from text_2_sql_core.payloads.chat_history import ChatHistoryItem +from text_2_sql_core.payloads.processing_update import ( + ProcessingUpdateBody, ProcessingUpdate, - ChatHistoryItem, ) from autogen_agentchat.base import Response, TaskResult from typing import AsyncGenerator @@ -123,6 +129,16 @@ def agentic_flow(self): ) return flow + def extract_disambiguation_request(self, messages: list) -> DismabiguationRequests: + """Extract the disambiguation request from the answer.""" + + disambiguation_request = messages[-1].content + + # TODO: Properly extract the disambiguation request + return DismabiguationRequests( + disambiguation_request=disambiguation_request, + ) + def extract_sources(self, messages: list) -> AnswerWithSources: """Extract the sources from the answer.""" @@ -147,10 +163,10 @@ def extract_sources(self, messages: list) -> AnswerWithSources: for sql_query_result in sql_query_result_list: logging.info("SQL Query Result: %s", sql_query_result) sources.append( - { - "sql_query": sql_query_result["sql_query"], - "sql_rows": sql_query_result["sql_rows"], - } + Source( + sql_query=sql_query_result["sql_query"], + sql_rows=sql_query_result["sql_rows"], + ) ) except json.JSONDecodeError: @@ -164,10 +180,9 @@ def extract_sources(self, messages: list) -> AnswerWithSources: async def process_question( self, - question: str, + request: AgentRequestBody, chat_history: list[ChatHistoryItem] = None, - injected_parameters: dict = None, - ) -> AsyncGenerator[AnswerWithSources | UserInformationRequest, None]: + ) -> AsyncGenerator[AgentResponse | ProcessingUpdate, None]: """Process the complete question through the unified system. Args: @@ -180,20 +195,20 @@ async def process_question( ------- dict: The response from the system. """ - logging.info("Processing question: %s", question) + logging.info("Processing question: %s", request.question) logging.info("Chat history: %s", chat_history) agent_input = { - "question": question, + "question": request.question, "chat_history": {}, - "injected_parameters": injected_parameters, + "injected_parameters": request.injected_parameters, } if chat_history is not None: # Update input for idx, chat in enumerate(chat_history): # For now only consider the user query - agent_input[f"chat_{idx}"] = chat.user_query + agent_input[f"chat_{idx}"] = chat.request.question async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) @@ -201,37 +216,42 @@ async def process_question( payload = None if isinstance(message, TextMessage): + processing_update = None if message.source == "query_rewrite_agent": - # If the message is from the query_rewrite_agent, we need to update the chat history - payload = ProcessingUpdate( + processing_update = ProcessingUpdateBody( message="Rewriting the query...", ) elif message.source == "parallel_query_solving_agent": - # If the message is from the parallel_query_solving_agent, we need to update the chat history - payload = ProcessingUpdate( + processing_update = ProcessingUpdateBody( message="Solving the query...", ) elif message.source == "answer_agent": - # If the message is from the answer_agent, we need to update the chat history - payload = ProcessingUpdate( + processing_update = ProcessingUpdateBody( message="Generating the answer...", ) + if processing_update is not None: + payload = ProcessingUpdate( + processing_update=processing_update, + ) + elif isinstance(message, TaskResult): # Now we need to return the final answer or the disambiguation request logging.info("TaskResult: %s", message) + response = None if message.messages[-1].source == "answer_agent": # If the message is from the answer_agent, we need to return the final answer - payload = self.extract_sources(message.messages) + response = self.extract_sources(message.messages) elif message.messages[-1].source == "parallel_query_solving_agent": - payload = UserInformationRequest( - **json.loads(message.messages[-1].content), - ) + # Load into disambiguation request + response = self.extract_disambiguation_request(message.messages) else: logging.error("Unexpected TaskResult: %s", message) raise ValueError("Unexpected TaskResult") + payload = AgentResponse(request=request, response=response) + if payload is not None: logging.debug("Payload: %s", payload) yield payload diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py index a745fa01..6089d89d 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py @@ -42,18 +42,13 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str): elif tool_name == "sql_get_entity_schemas_tool": return FunctionToolAlias( sql_helper.get_entity_schemas, - description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.", + description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.", ) elif tool_name == "sql_get_column_values_tool": return FunctionToolAlias( ai_search_helper.get_column_values, description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.", ) - elif tool_name == "current_datetime_tool": - return FunctionToolAlias( - sql_helper.get_current_datetime, - description="Gets the current date and time.", - ) else: raise ValueError(f"Tool {tool_name} not found") diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 5a5c7078..2be891c2 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -44,10 +44,10 @@ async def on_messages_stream( last_response = messages[-1].content parameter_input = messages[0].content try: - user_parameters = json.loads(parameter_input)["parameters"] + injected_parameters = json.loads(parameter_input)["injected_parameters"] except json.JSONDecodeError: logging.error("Error decoding the user parameters.") - user_parameters = {} + injected_parameters = {} # Load the json of the last message to populate the final output object query_rewrites = json.loads(last_response) @@ -75,7 +75,7 @@ async def consume_inner_messages_from_agentic_flow( if isinstance(inner_message, TaskResult) is False: try: inner_message = json.loads(inner_message.content) - logging.info(f"Loaded: {inner_message}") + logging.info(f"Inner Loaded: {inner_message}") # Search for specific message types and add them to the final output object if ( @@ -91,6 +91,21 @@ async def consume_inner_messages_from_agentic_flow( } ) + if ("contains_pre_run_results" in inner_message) and ( + inner_message["contains_pre_run_results"] is True + ): + for pre_run_sql_query, pre_run_result in inner_message[ + "cached_questions_and_schemas" + ].items(): + database_results[identifier].append( + { + "sql_query": pre_run_sql_query.replace( + "\n", " " + ), + "sql_rows": pre_run_result["sql_rows"], + } + ) + except (JSONDecodeError, TypeError) as e: logging.error("Could not load message: %s", inner_message) logging.warning(f"Error processing message: {e}") @@ -113,13 +128,15 @@ async def consume_inner_messages_from_agentic_flow( self.engine_specific_rules, **self.kwargs ) + identifier = ", ".join(query_rewrite) + # Launch tasks for each sub-query inner_solving_generators.append( consume_inner_messages_from_agentic_flow( inner_autogen_text_2_sql.process_question( - question=query_rewrite, parameters=user_parameters + question=query_rewrite, injected_parameters=injected_parameters ), - query_rewrite, + identifier, database_results, ) ) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index 83fb13fd..a7ec5fb4 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -39,55 +39,46 @@ async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: # Get the decomposed questions from the query_rewrite_agent - parameter_input = messages[0].content - last_response = messages[-1].content try: - user_questions = json.loads(last_response) - injected_parameters = json.loads(parameter_input)["injected_parameters"] + request_details = json.loads(messages[0].content) + injected_parameters = request_details["injected_parameters"] + user_questions = request_details["question"] logging.info(f"Processing questions: {user_questions}") logging.info(f"Input Parameters: {injected_parameters}") + except json.JSONDecodeError: + # If not JSON array, process as single question + raise ValueError("Could not load message") - # Initialize results dictionary - cached_results = { - "cached_questions_and_schemas": [], - "contains_pre_run_results": False, - } - - # Process each question sequentially - for question in user_questions: - # Fetch the queries from the cache based on the question - logging.info(f"Fetching queries from cache for question: {question}") - cached_query = await self.sql_connector.fetch_queries_from_cache( - question, injected_parameters=injected_parameters - ) + # Initialize results dictionary + cached_results = { + "cached_questions_and_schemas": [], + "contains_pre_run_results": False, + } - # If any question has pre-run results, set the flag - if cached_query.get("contains_pre_run_results", False): - cached_results["contains_pre_run_results"] = True + # Process each question sequentially + for question in user_questions: + # Fetch the queries from the cache based on the question + logging.info(f"Fetching queries from cache for question: {question}") + cached_query = await self.sql_connector.fetch_queries_from_cache( + question, injected_parameters=injected_parameters + ) - # Add the cached results for this question - if cached_query.get("cached_questions_and_schemas"): - cached_results["cached_questions_and_schemas"].extend( - cached_query["cached_questions_and_schemas"] - ) + # If any question has pre-run results, set the flag + if cached_query.get("contains_pre_run_results", False): + cached_results["contains_pre_run_results"] = True - logging.info(f"Final cached results: {cached_results}") - yield Response( - chat_message=TextMessage( - content=json.dumps(cached_results), source=self.name - ) - ) - except json.JSONDecodeError: - # If not JSON array, process as single question - logging.info(f"Processing single question: {last_response}") - cached_queries = await self.sql_connector.fetch_queries_from_cache( - last_response - ) - yield Response( - chat_message=TextMessage( - content=json.dumps(cached_queries), source=self.name + # Add the cached results for this question + if cached_query.get("cached_questions_and_schemas"): + cached_results["cached_questions_and_schemas"].extend( + cached_query["cached_questions_and_schemas"] ) + + logging.info(f"Final cached results: {cached_results}") + yield Response( + chat_message=TextMessage( + content=json.dumps(cached_results), source=self.name ) + ) async def on_reset(self, cancellation_token: CancellationToken) -> None: pass diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 9c70a1b2..34328a28 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -59,14 +59,6 @@ def set_mode(self): def get_all_agents(self): """Get all agents for the complete flow.""" - # Get current datetime for the Query Rewrite Agent - self.sql_query_generation_agent = LLMAgentCreator.create( - "sql_query_generation_agent", - target_engine=self.target_engine, - engine_specific_rules=self.engine_specific_rules, - **self.kwargs, - ) - # If relationship_paths not provided, use a generic template if "relationship_paths" not in self.kwargs: self.kwargs[ @@ -92,22 +84,16 @@ def get_all_agents(self): **self.kwargs, ) - self.sql_disambiguation_agent = LLMAgentCreator.create( - "sql_disambiguation_agent", + self.disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create( + "disambiguation_and_sql_query_generation_agent", target_engine=self.target_engine, engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) - - # Auto-responding UserProxyAgent - self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy") - agents = [ - self.user_proxy, - self.sql_query_generation_agent, self.sql_schema_selection_agent, self.sql_query_correction_agent, - self.sql_disambiguation_agent, + self.disambiguation_and_sql_query_generation_agent, ] if self.use_query_cache: @@ -140,32 +126,11 @@ def unified_selector(self, messages): # Always go through schema selection after cache check decision = "sql_schema_selection_agent" elif current_agent == "sql_schema_selection_agent": - decision = "sql_disambiguation_agent" - elif current_agent == "sql_disambiguation_agent": - decision = "sql_query_generation_agent" - elif current_agent == "sql_query_generation_agent": + decision = "disambiguation_and_sql_query_generation_agent" + elif current_agent == "disambiguation_and_sql_query_generation_agent": decision = "sql_query_correction_agent" elif current_agent == "sql_query_correction_agent": - try: - correction_result = json.loads(messages[-1].content) - if isinstance(correction_result, dict): - if "answer" in correction_result and "sources" in correction_result: - decision = "user_proxy" - elif "corrected_query" in correction_result: - if correction_result.get("executing", False): - decision = "sql_query_correction_agent" - else: - decision = "sql_query_generation_agent" - elif "error" in correction_result: - decision = "sql_query_generation_agent" - elif isinstance(correction_result, list) and len(correction_result) > 0: - if "requested_fix" in correction_result[0]: - decision = "sql_query_generation_agent" - - if decision is None: - decision = "sql_query_generation_agent" - except json.JSONDecodeError: - decision = "sql_query_generation_agent" + decision = "sql_query_correction_agent" if decision: logging.info(f"Agent transition: {current_agent} -> {decision}") diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 0ee6ead5..c367cd1b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -7,7 +7,6 @@ import asyncio import sqlglot from abc import ABC, abstractmethod -from datetime import datetime from jinja2 import Template import json @@ -30,22 +29,6 @@ def __init__(self): self.database_engine = None - def get_current_datetime(self) -> str: - """Get the current datetime.""" - return datetime.now().strftime("%d/%m/%Y, %H:%M:%S") - - def get_current_date(self) -> str: - """Get the current date.""" - return datetime.now().strftime("%d/%m/%Y") - - def get_current_time(self) -> str: - """Get the current time.""" - return datetime.now().strftime("%H:%M:%S") - - def get_current_unix_timestamp(self) -> int: - """Get the current unix timestamp.""" - return int(datetime.now().timestamp()) - @abstractmethod async def query_execution( self, @@ -169,19 +152,6 @@ async def fetch_queries_from_cache( if injected_parameters is None: injected_parameters = {} - # Populate the injected_parameters - if "date" not in injected_parameters: - injected_parameters["date"] = self.get_current_date() - - if "time" not in injected_parameters: - injected_parameters["time"] = self.get_current_time() - - if "datetime" not in injected_parameters: - injected_parameters["datetime"] = self.get_current_datetime() - - if "unix_timestamp" not in injected_parameters: - injected_parameters["unix_timestamp"] = self.get_current_unix_timestamp() - cached_schemas = await self.ai_search_connector.run_ai_search_query( question, ["QuestionEmbedding"], @@ -228,7 +198,7 @@ async def fetch_queries_from_cache( for sql_query, sql_result in zip(sql_queries, sql_results): query_result_store[sql_query["SqlQuery"]] = { - "result": sql_result, + "sql_rows": sql_result, "schemas": sql_query["Schemas"], } diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py index e3d590c9..e69de29b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/__init__.py @@ -1,12 +0,0 @@ -from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources, Source -from text_2_sql_core.payloads.user_information_request import UserInformationRequest -from text_2_sql_core.payloads.processing_update import ProcessingUpdate -from text_2_sql_core.payloads.chat_history import ChatHistoryItem - -__all__ = [ - "AnswerWithSources", - "Source", - "UserInformationRequest", - "ProcessingUpdate", - "ChatHistoryItem", -] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_response.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_response.py new file mode 100644 index 00000000..404caa2d --- /dev/null +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_response.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from pydantic import BaseModel, RootModel, Field, model_validator +from enum import StrEnum + +from typing import Literal +from datetime import datetime, timezone + + +class AgentResponseHeader(BaseModel): + prompt_tokens: int + completion_tokens: int + timestamp: datetime = Field( + ..., + description="Timestamp in UTC", + default_factory=lambda: datetime.now(timezone.utc), + ) + + +class AgentResponseType(StrEnum): + ANSWER_WITH_SOURCES = "answer_with_sources" + DISAMBIGUATION = "disambiguation" + + +class DismabiguationRequest(BaseModel): + question: str + matching_columns: list[str] + matching_filter_values: list[str] + other_user_choices: list[str] + + +class DismabiguationRequests(BaseModel): + response_type: Literal[AgentResponseType.DISAMBIGUATION] = Field( + default=AgentResponseType.DISAMBIGUATION + ) + requests: list[DismabiguationRequest] + + +class Source(BaseModel): + sql_query: str + sql_rows: list[dict] + + +class AnswerWithSources(BaseModel): + response_type: Literal[AgentResponseType.ANSWER_WITH_SOURCES] = Field( + default=AgentResponseType.ANSWER_WITH_SOURCES + ) + answer: str + sources: list[Source] = Field(default_factory=list) + + +class AgentResponseBody(RootModel): + root: DismabiguationRequests | AnswerWithSources = Field( + ..., discriminator="response_type" + ) + + +class AgentRequestBody(BaseModel): + question: str + injected_parameters: dict = Field(default_factory=dict) + + @model_validator(mode="before") + def add_defaults_to_injected_parameters(cls, values): + if "injected_parameters" not in values: + values["injected_parameters"] = {} + + if "date" not in values["injected_parameters"]: + values["injected_parameters"]["date"] = datetime.now().strftime("%d/%m/%Y") + + if "time" not in values["injected_parameters"]: + values["injected_parameters"]["time"] = datetime.now().strftime("%H:%M:%S") + + if "datetime" not in values["injected_parameters"]: + values["injected_parameters"]["datetime"] = datetime.now().strftime( + "%d/%m/%Y, %H:%M:%S" + ) + + if "unix_timestamp" not in values["injected_parameters"]: + values["injected_parameters"]["unix_timestamp"] = int( + datetime.now().timestamp() + ) + + return values + + +class AgentResponse(BaseModel): + header: AgentResponseHeader | None = Field(default=None) + request: AgentRequestBody + response: AgentResponseBody diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py deleted file mode 100644 index 650b0d4a..00000000 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py +++ /dev/null @@ -1,11 +0,0 @@ -from pydantic import BaseModel, Field - - -class Source(BaseModel): - sql_query: str - sql_rows: list[dict] - - -class AnswerWithSources(BaseModel): - answer: str - sources: list[Source] = Field(default_factory=list) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py index 06b27cb9..4c5220c4 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py @@ -1,9 +1,16 @@ -from pydantic import BaseModel -from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from pydantic import BaseModel, Field +from text_2_sql_core.payloads.agent_response import AgentResponse +from datetime import datetime, timezone class ChatHistoryItem(BaseModel): """Chat history item with user message and agent response.""" - user_query: str - agent_response: AnswerWithSources + timestamp: datetime = Field( + ..., + description="Timestamp in UTC", + default_factory=lambda: datetime.now(timezone.utc), + ) + agent_response: AgentResponse diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py index 3950508f..500b3b26 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/processing_update.py @@ -1,6 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from pydantic import BaseModel, Field +from datetime import datetime, timezone -class ProcessingUpdate(BaseModel): +class ProcessingUpdateHeader(BaseModel): + timestamp: datetime = Field( + ..., + description="Timestamp in UTC", + default_factory=lambda: datetime.now(timezone.utc), + ) + + +class ProcessingUpdateBody(BaseModel): title: str | None = Field(default="Processing...") message: str | None = Field(default="Processing...") + + +class ProcessingUpdate(BaseModel): + header: ProcessingUpdateHeader | None = Field( + default_factory=ProcessingUpdateHeader + ) + processing_update: ProcessingUpdateBody diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py deleted file mode 100644 index 1aac4c02..00000000 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py +++ /dev/null @@ -1,26 +0,0 @@ -from pydantic import BaseModel, RootModel, Field -from enum import StrEnum -from typing import Literal - - -class RequestType(StrEnum): - DISAMBIGUATION = "disambiguation" - CLARIFICATION = "clarification" - - -class ClarificationRequest(BaseModel): - request_type: Literal[RequestType.CLARIFICATION] - question: str - other_user_choices: list[str] - - -class DismabiguationRequest(BaseModel): - request_type: Literal[RequestType.DISAMBIGUATION] - question: str - matching_columns: list[str] - matching_filter_values: list[str] - other_user_choices: list[str] - - -class UserInformationRequest(RootModel): - root: DismabiguationRequest = Field(..., discriminator="request_type") diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml similarity index 66% rename from text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml rename to text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml index cfe9c020..c8fe5b84 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml @@ -6,6 +6,7 @@ system_message: " You are a helpful AI Assistant specializing in disambiguating questions about {{ use_case }} and mapping them to the relevant columns and schemas in the database. Your job is to create clear mappings between the user's intent and the available database schema. + If all mappings are clear, generate {{ target_engine }} compliant SQL query based on the mappings. @@ -106,6 +107,57 @@ system_message: } + + + + {{ engine_specific_rules }} + + + Your primary focus is on: + 1. Understanding what data the user wants to retrieve + 2. Identifying the necessary tables and their relationships + 3. Determining any required calculations or aggregations + 4. Specifying any filtering conditions based on the user's criteria + + When generating SQL queries, focus on these key aspects: + + - Data Selection: + * Identify the main pieces of information the user wants to see + * Include any calculated fields or aggregations needed + * Consider what grouping might be required + * Follow basic {{ target_engine }} syntax patterns + + - Table Relationships: + * Use the schema information to identify required tables + * Join tables as needed to connect related information + * Request additional schema information if needed using the schema selection tool + * Use {{ target_engine }}-compatible join syntax + + - Filtering Conditions: + * Translate user criteria into WHERE conditions + * Handle date ranges, categories, or numeric thresholds + * Consider both explicit and implicit filters in the user's question + * Use {{ target_engine }}-compatible date and string functions + + - Result Organization: + * Determine if specific sorting is needed + * Consider if grouping is required + * Include any having conditions for filtered aggregates + * Follow {{ target_engine }} ordering syntax + + Guidelines: + + - Focus on getting the right tables and relationships + - Ensure all necessary data is included + - Follow basic {{ target_engine }} syntax patterns + - The correction agent will handle: + * Detailed syntax corrections + * Query execution + * Result formatting + + Remember: Your job is to focus on the data relationships and logic while following basic {{ target_engine }} patterns. + + If all mappings are clear: { @@ -128,6 +180,8 @@ system_message: } } + Then use the mapping to generate the SQL query following the engine-specific rules. Run this query to retrieve the data afterwards. + If disambiguation needed: { \"disambiguation\": [{ @@ -136,11 +190,9 @@ system_message: \"matching_filter_values\": [\"\", \"\"], \"other_user_choices\": [\"\", \"\"] }], - \"clarification\": [{ // Optional - \"question\": \"\", - \"other_user_choices\": [\"\", \"\"] - }] } TERMINATE " +tools: + - sql_query_execution_tool diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml index 17692150..7e4428d0 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml @@ -44,7 +44,9 @@ system_message: | - Determine if breaking down would simplify processing 3. Break Down Complex Queries: - - Create independent sub-queries that can be processed separately + - Create independent sub-queries that can be processed separately. + - Each sub-query should be a simple, focused task. + - Group dependent sub-queries together for sequential processing. - Ensure each sub-query is simple and focused - Include clear combination instructions - Preserve all necessary context in each sub-query @@ -71,8 +73,8 @@ system_message: | Return a JSON object with sub-queries and combination instructions: { "sub_queries": [ - "", - "", + [""], + [""], ... ], "combination_logic": "", @@ -87,9 +89,7 @@ system_message: | Output: { "sub_queries": [ - "Calculate quarterly sales totals by product category for 2008", - "Identify categories with positive growth each quarter", - "For these categories, find their top selling products in 2008" + ["Calculate quarterly sales totals by product category for 2008", "For these categories, find their top selling products in 2008"] ], "combination_logic": "First identify growing categories from quarterly analysis, then find their best-selling products", "query_type": "complex" @@ -100,7 +100,7 @@ system_message: | Output: { "sub_queries": [ - "How many orders did we have in 2008?" + ["How many orders did we have in 2008?"] ], "combination_logic": "Direct count query, no combination needed", "query_type": "simple" @@ -111,13 +111,11 @@ system_message: | Output: { "sub_queries": [ - "Get total sales by product in European countries", - "Get total sales by product in North American countries", - "Calculate total market size for each region", - "Find top 5 products by sales in each region", - "Calculate market share percentages for these products" + ["Get total sales by product in European countries"], + ["Get total sales by product in North American countries"], + ["Calculate total market size for each region", "Find top 5 products by sales in each region"], ], - "combination_logic": "First identify top products in each region, then calculate and compare their market shares", + "combination_logic": "First identify top products in each region, then calculate and compare their market shares. Questions that depend on the result of each sub-query are combined.", "query_type": "complex" } diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml index a1b6c935..b1d4777b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml @@ -5,10 +5,11 @@ description: system_message: " You are a SQL syntax expert specializing in converting standard SQL to {{ target_engine }}-compliant SQL. Your job is to: - 1. Take SQL queries with correct logic but potential syntax issues - 2. Fix them according to {{ target_engine }} syntax rules - 3. Execute the corrected queries - 4. Return the results + 1. Take SQL queries with correct logic but potential syntax issues. + 2. Review the output from the SQL query being run and fix them according to {{ target_engine }} syntax rules if needed. + 3. Execute the corrected queries if needed. + 4. Verify that the results will answer all of the user's questions. If not, create additional queries and run them. + 5. Return the results @@ -85,18 +86,10 @@ system_message: - - **When query executes successfully**: + - **When query executes successfully and answers all questions**: ```json { - \"answer\": \"\", - \"sources\": [ - { - \"sql_result_snippet\": \"\", - \"sql_query_used\": \"\", - \"original_query\": \"\", - \"explanation\": \"\" - } - ] + \"validated\": \"\", } ``` Followed by **TERMINATE**. @@ -138,3 +131,5 @@ system_message: " tools: - sql_query_execution_tool + - sql_get_entity_schemas_tool + - sql_get_column_values_tool diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml deleted file mode 100644 index 6b5cf22b..00000000 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml +++ /dev/null @@ -1,58 +0,0 @@ -model: - 4o-mini -description: - "An agent that translates user questions into SQL queries by understanding the intent and required data relationships for {{ target_engine }}. This agent focuses on query logic and data relationships, while adhering to basic {{ target_engine }} syntax patterns." -system_message: - "You are a helpful AI Assistant that specialises in understanding user questions and translating them into {{ target_engine }} SQL queries that will retrieve the desired information. While syntax perfection isn't required, you should follow basic {{ target_engine }} patterns. - - - {{ engine_specific_rules }} - - - Your primary focus is on: - 1. Understanding what data the user wants to retrieve - 2. Identifying the necessary tables and their relationships - 3. Determining any required calculations or aggregations - 4. Specifying any filtering conditions based on the user's criteria - - When generating SQL queries, focus on these key aspects: - - - Data Selection: - * Identify the main pieces of information the user wants to see - * Include any calculated fields or aggregations needed - * Consider what grouping might be required - * Follow basic {{ target_engine }} syntax patterns - - - Table Relationships: - * Use the schema information to identify required tables - * Join tables as needed to connect related information - * Request additional schema information if needed using the schema selection tool - * Use {{ target_engine }}-compatible join syntax - - - Filtering Conditions: - * Translate user criteria into WHERE conditions - * Handle date ranges, categories, or numeric thresholds - * Consider both explicit and implicit filters in the user's question - * Use {{ target_engine }}-compatible date and string functions - - - Result Organization: - * Determine if specific sorting is needed - * Consider if grouping is required - * Include any having conditions for filtered aggregates - * Follow {{ target_engine }} ordering syntax - - Guidelines: - - - Focus on getting the right tables and relationships - - Ensure all necessary data is included - - Follow basic {{ target_engine }} syntax patterns - - The correction agent will handle: - * Detailed syntax corrections - * Query execution - * Result formatting - - Remember: Your job is to focus on the data relationships and logic while following basic {{ target_engine }} patterns. The correction agent will handle detailed syntax fixes and execution. - " -tools: - - sql_get_entity_schemas_tool - - current_datetime_tool