Skip to content

Text2SQL Limit Clause Fix #167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions text_2_sql/autogen/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ authors = [
requires-python = ">=3.11"
dependencies = [
"aiostream>=0.6.4",
"autogen-agentchat==0.4.5",
"autogen-core==0.4.5",
"autogen-ext[azure,openai]==0.4.5",
"autogen-agentchat==0.4.7",
"autogen-core==0.4.7",
"autogen-ext[azure,openai]==0.4.7",
"grpcio>=1.68.1",
"pyyaml>=6.0.2",
"text_2_sql_core",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,7 @@ def termination_condition(self):
termination = (
SourceMatchTermination("answer_agent")
| SourceMatchTermination("answer_with_follow_up_suggestions_agent")
# | TextMentionTermination(
# "[]",
# sources=["user_message_rewrite_agent"],
# )
| TextMentionTermination(
"contains_disambiguation_requests",
sources=["parallel_query_solving_agent"],
)
| TextMentionTermination("contains_disambiguation_requests")
| MaxMessageTermination(5)
)
return termination
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from aiostream import stream
from json import JSONDecodeError
import re
import os
from pydantic import BaseModel, Field


Expand Down Expand Up @@ -226,23 +225,12 @@ async def consume_inner_messages_from_agentic_flow(
# Create an instance of the InnerAutoGenText2Sql class
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)

# Add database connection info to injected parameters
query_params = injected_parameters.copy() if injected_parameters else {}
if "Text2Sql__Tsql__ConnectionString" in os.environ:
query_params["database_connection_string"] = os.environ[
"Text2Sql__Tsql__ConnectionString"
]
if "Text2Sql__Tsql__Database" in os.environ:
query_params["database_name"] = os.environ[
"Text2Sql__Tsql__Database"
]

# Launch tasks for each sub-query
inner_solving_generators.append(
consume_inner_messages_from_agentic_flow(
inner_autogen_text_2_sql.process_user_message(
user_message=parallel_message,
injected_parameters=query_params,
injected_parameters=injected_parameters,
database_results=filtered_parallel_messages.database_results,
),
parallel_message,
Expand Down Expand Up @@ -294,7 +282,7 @@ async def consume_inner_messages_from_agentic_flow(
),
)

break
return

# Final response
yield Response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,6 @@ def __init__(self, **kwargs: dict):
self.kwargs = kwargs
self.set_mode()

# Store original environment variables
self.original_db_conn = os.environ.get("Text2Sql__Tsql__ConnectionString")
self.original_db_name = os.environ.get("Text2Sql__Tsql__Database")

def _update_environment(self, injected_parameters: dict = None):
"""Update environment variables with injected parameters."""
if injected_parameters:
if "database_connection_string" in injected_parameters:
os.environ["Text2Sql__Tsql__ConnectionString"] = injected_parameters[
"database_connection_string"
]
if "database_name" in injected_parameters:
os.environ["Text2Sql__Tsql__Database"] = injected_parameters[
"database_name"
]

def _restore_environment(self):
"""Restore original environment variables."""
if self.original_db_conn:
os.environ["Text2Sql__Tsql__ConnectionString"] = self.original_db_conn
if self.original_db_name:
os.environ["Text2Sql__Tsql__Database"] = self.original_db_name

def set_mode(self):
"""Set the mode of the plugin based on the environment variables."""
self.pre_run_query_cache = (
Expand Down Expand Up @@ -195,19 +172,12 @@ def process_user_message(
"""
logging.info("Processing question: %s", user_message)

# Update environment with injected parameters
self._update_environment(injected_parameters)

try:
agent_input = {
"user_message": user_message,
"injected_parameters": injected_parameters,
}
agent_input = {
"user_message": user_message,
"injected_parameters": injected_parameters,
}

if database_results:
agent_input["database_results"] = database_results
if database_results:
agent_input["database_results"] = database_results

return self.agentic_flow.run_stream(task=json.dumps(agent_input))
finally:
# Restore original environment
self._restore_environment()
return self.agentic_flow.run_stream(task=json.dumps(agent_input))
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,22 @@ def handle_node(node):
current_limit = parsed_query.args.get("limit")
logging.debug("Current Limit: %s", current_limit)

if current_limit is None or current_limit.value > self.row_limit:
# More defensive check to handle different structures
should_add_limit = True
if current_limit is not None:
try:
if hasattr(current_limit, "expression") and hasattr(
current_limit.expression, "value"
):
if current_limit.expression.value <= self.row_limit:
should_add_limit = False
except AttributeError:
logging.warning("Unexpected limit structure: %s", current_limit)

if should_add_limit:
# Create a new LIMIT expression
limit_expr = Limit(expression=Literal.number(self.row_limit))

# Attach it to the query by setting it on the SELECT expression
# Attach it to the query
parsed_query.set("limit", limit_expr)
updated_parsed_queries.append(
parsed_query.sql(dialect=self.database_engine.value.lower())
Expand Down
24 changes: 12 additions & 12 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading