Agent Disambiguation #94

Merged: 4 commits, Dec 11, 2024
12 changes: 10 additions & 2 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
@@ -1,6 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination
from autogen_agentchat.conditions import (
TextMentionTermination,
MaxMessageTermination,
SourceMatchTermination,
)
from autogen_agentchat.teams import SelectorGroupChat
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
from autogen_text_2_sql.creators.llm_agent_creator import LLMAgentCreator
@@ -89,7 +93,11 @@ def agents(self):
@property
def termination_condition(self):
"""Define the termination condition for the chat."""
termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(20)
termination = (
TextMentionTermination("TERMINATE")
| MaxMessageTermination(20)
| SourceMatchTermination(["answer_agent"])
)
return termination

@staticmethod
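For reference, a minimal standalone sketch of how the new termination logic composes; it mirrors the diff above and assumes only the condition classes already imported in this file.

# Minimal sketch mirroring the diff above; requires autogen-agentchat installed.
from autogen_agentchat.conditions import (
    TextMentionTermination,
    MaxMessageTermination,
    SourceMatchTermination,
)

# The chat now stops as soon as ANY one of the three conditions fires:
termination = (
    TextMentionTermination("TERMINATE")         # an agent emits the keyword
    | MaxMessageTermination(20)                 # hard cap of 20 messages
    | SourceMatchTermination(["answer_agent"])  # answer_agent has replied
)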
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_ext.models import AzureOpenAIChatCompletionClient
from text_2_sql_core.connectors.factory import ConnectorFactory
from text_2_sql_core.utils.environment import IdentityType, get_identity_type

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import os
import dotenv

@@ -27,12 +28,32 @@ def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
else:
raise ValueError(f"Model {model_name} not found")

@classmethod
def get_authentication_properties(cls) -> dict:
if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
# Create the token provider
api_key = None
token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
elif get_identity_type() == IdentityType.USER_ASSIGNED:
# Create the token provider
api_key = None
token_provider = get_bearer_token_provider(
DefaultAzureCredential(
managed_identity_client_id=os.environ["ClientId"]
),
"https://cognitiveservices.azure.com/.default",
)
else:
token_provider = None
api_key = os.environ["OpenAI__ApiKey"]

return token_provider, api_key

@classmethod
def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
(
token_provider,
api_key,
) = ConnectorFactory.get_open_ai_connector().get_authentication_properties()
token_provider, api_key = cls.get_authentication_properties()
return AzureOpenAIChatCompletionClient(
azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
model=os.environ["OpenAI__MiniCompletionDeployment"],
@@ -45,14 +66,12 @@ def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
"function_calling": True,
"json_output": True,
},
temperature=0,
)

@classmethod
def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
(
token_provider,
api_key,
) = ConnectorFactory.get_open_ai_connector().get_authentication_properties()
token_provider, api_key = cls.get_authentication_properties()
return AzureOpenAIChatCompletionClient(
azure_deployment=os.environ["OpenAI__CompletionDeployment"],
model=os.environ["OpenAI__CompletionDeployment"],
@@ -65,4 +84,5 @@ def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
"function_calling": True,
"json_output": True,
},
temperature=0,
)
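A hedged usage sketch of the refactored authentication flow; it assumes these classmethods belong to the LLMModelCreator imported earlier in the PR and that the environment variables referenced in the diff are set.

# Illustrative only: assumes the classmethods above live on LLMModelCreator and
# that the OpenAI__* environment variables shown in the diff are configured.
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator

# Exactly one of the pair is populated: a bearer-token provider when a managed
# identity is in use, or the raw key read from OpenAI__ApiKey otherwise.
token_provider, api_key = LLMModelCreator.get_authentication_properties()

# Both factory methods now pass temperature=0, keeping SQL generation as
# repeatable as possible across runs.
model_client = LLMModelCreator.gpt_4o_model()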
@@ -148,7 +148,7 @@ async def get_column_values(
"AIService__AzureSearchOptions__Text2SqlColumnValueStore__Index"
],
semantic_config=None,
top=10,
top=15,
include_scores=False,
minimum_score=5,
)
@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from text_2_sql_core.connectors.ai_search import AISearchConnector
from text_2_sql_core.connectors.open_ai import OpenAIConnector
@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License
from openai import AsyncAzureOpenAI
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import os
@@ -10,11 +10,8 @@ system_message:
{
'answer': '<GENERATED ANSWER>',
'sources': [
{'chunk': <SOURCE 1 CONTEXT CHUNK>, 'reference': '<SOURCE 1 SQL QUERY>', 'explanation': '<EXPLANATION OF SQL QUERY 1>'},
{'chunk': <SOURCE 2 CONTEXT CHUNK>, 'reference': '<SOURCE 2 SQL QUERY>', 'explanation': '<EXPLANATION OF SQL QUERY 2>'},
{'sql_result_snippet': <SQL QUERY RESULT 1>, 'sql_query_used': '<SOURCE 1 SQL QUERY>', 'explanation': '<EXPLANATION OF SQL QUERY 1>'},
{'sql_result_snippet': <SQL QUERY RESULT 2>, 'sql_query_used': '<SOURCE 2 SQL QUERY>', 'explanation': '<EXPLANATION OF SQL QUERY 2>'},
]
}

Title is the entity name of the schema, chunk is the result of the SQL query and reference is the SQL query used to generate the answer.

End your answer with 'TERMINATE'"
"
@@ -8,57 +8,79 @@ system_message:
</role_and_objective>

<scope_of_user_query>
The user's question will be related to {{ use_case }}.
The user's question will be related to {{ use_case }}.
</scope_of_user_query>

<instructions>
- For every intent and filter condition in the question, map them to the columns in the schemas. Use the whole context of the question and information already provided to do so.

- Do not ask for information already included in the question, schema, or what can reasonably be inferred from the question.
- Only provide possible filter values for string columns. Do not provide possible filter values for Date and Numerical values as it should be clear from the question. Only ask a follow up question for Date and Numerical values if you are unsure which column to use or what the value means e.g. does 100 in currency refer to 100 USD or 100 EUR.

<clear_context_handling>
- If the context of the question makes the mapping explicit, directly map the terms to the relevant columns without generating disambiguation questions.
- Use the following checks to decide:
- Does the term directly match a single schema column without overlaps? Use the 'column_values' property to check for possible matching columns and compare these to the context of the question. If there are multiple possible columns for a given user's filter, then apply disambiguation.
- Does the user's question provide additional context (e.g., \"product line\" or \"category\") clarifying the intent?
- If **all mappings are clear**, output the JSON with mappings only.
- Example:
- Question: \"What are the total number of sales within 2008 for the mountain bike product line?\"
- Output:
- Only provide possible filter values for string columns. Do not provide possible filter values for Date and Numerical values as it should be clear from the question. Only ask a follow-up question for Date and Numerical values if you are unsure which column to use or what the value means, e.g., does 100 in currency refer to 100 USD or 100 EUR.

<clear_context_handling>
If the context of the question makes the mapping explicit, directly map the terms to the relevant column FQN without generating disambiguation questions.

Use the 'column_values' property to check for possible matching columns and compare these to the context of the question.

When evaluating filters:

Always consider the temporal and contextual phrases (e.g., \"June 2008\") in the question. If the context implies a direct match to a date column, do not request clarification unless multiple plausible columns exist.
For geographical or categorical terms (e.g., \"country\"), prioritize unique matches or add context to narrow down ambiguities based on the schema.
If all mappings are clear, output the JSON with mappings only.

<example>
Question: \"What are the total number of sales within 2008 for the mountain bike product line?\"
Output:
{
\"mapping\": {
\"Mountain\": \"vProductModelCatalogDescription.ProductLine\",
\"Mountain Bike\": \"vProductModelCatalogDescription.Category\",
\"2008\": \"SalesLT.SalesOrderHeader.OrderDate\"
}
}
</clear_context_handling>
</example>
</clear_context_handling>

<disambiguation_handling>
If the term is ambiguous, there are multiple matching columns/filters, or the question lacks enough context to infer the correct mapping:

For ambiguous terms, evaluate the question context and schema relationships to narrow down matches.
Populate the 'filters' field with the identified filter and relevant FQN, matching columns, and possible filter values.
Include a clarification question in the 'question' field to request more information from the user.
If the clarification is not related to a column or a filter value, populate the 'user_choices' field with the possible choices they can select.

Prioritize clear disambiguation based on:
- Direct matches within schemas.
- Additional context provided by the question (e.g., temporal, categorical, or domain-specific keywords).

<example>
User question: \"What country did we sell the most to in June 2008?\"
Schema contains multiple columns potentially related to \"country.\"

<disambiguation_handling>
- If the term is ambiguous (e.g., \"Mountain Bike\") and the question lacks enough context to infer the correct mapping:
- e.g. The user asks about 'Bike'. From the 'column_values' you can see that 'Bike' appears in several different columns that are contextually related to the question. From this you are unsure if 'Bike' is a 'Category' or 'Product' column, you would populate the 'column' field with the possible columns for the user to disambiguate for you.
- Populate the 'filters' field with the identified filter and relevant FQN, matching columns, and possible filter values.
- Include a clarification question in the 'question' field to request more information from the user.
- Example:
If disambiguation is needed:
{
\"filters\": [
{
\"question\": \"What do you mean by 'Mountain Bike'?\",
\"question\": \"What do you mean by 'country'?\",
\"matching_columns\": [
\"vProductModelCatalogDescription.ProductLine\",
\"vProductAndDescription.Name\",
\"Product.Category\"
\"Sales.Country\",
\"Customers.Country\"
],
\"matching_filter_values\": [
\"Mountain\"
]
\"matching_filter_values\": [],
\"user_choices\": []
}
]
}
</disambiguation_handling>
</example>
Always include at least one of the 'matching_columns', 'matching_filter_values', or 'user_choices' fields in the 'filters' array.
</disambiguation_handling>
</instructions>

<output_format>
- If all mappings are clear, output the 'mapping' JSON only.
- If disambiguation is required, output the disambiguation JSON followed by \"TERMINATE\".
- Do not provide explanations or reasoning in the output.
If all mappings are clear, output the 'mapping' JSON only.
If disambiguation is required, output the disambiguation JSON followed by \"TERMINATE\".
Do not provide explanations or reasoning in the output.
</output_format>
"