diff --git a/text_2_sql/autogen/pyproject.toml b/text_2_sql/autogen/pyproject.toml index 7c08e4b..ebe5b5f 100644 --- a/text_2_sql/autogen/pyproject.toml +++ b/text_2_sql/autogen/pyproject.toml @@ -9,9 +9,9 @@ authors = [ requires-python = ">=3.11" dependencies = [ "aiostream>=0.6.4", - "autogen-agentchat==0.4.5", - "autogen-core==0.4.5", - "autogen-ext[azure,openai]==0.4.5", + "autogen-agentchat==0.4.7", + "autogen-core==0.4.7", + "autogen-ext[azure,openai]==0.4.7", "grpcio>=1.68.1", "pyyaml>=6.0.2", "text_2_sql_core", diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 4ea0bf5..bdd39fb 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -83,14 +83,7 @@ def termination_condition(self): termination = ( SourceMatchTermination("answer_agent") | SourceMatchTermination("answer_with_follow_up_suggestions_agent") - # | TextMentionTermination( - # "[]", - # sources=["user_message_rewrite_agent"], - # ) - | TextMentionTermination( - "contains_disambiguation_requests", - sources=["parallel_query_solving_agent"], - ) + | TextMentionTermination("contains_disambiguation_requests") | MaxMessageTermination(5) ) return termination diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py index 97a9e00..af61a6e 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py @@ -17,7 +17,6 @@ from aiostream import stream from json import JSONDecodeError import re -import os from pydantic import BaseModel, Field @@ -226,23 +225,12 @@ async def consume_inner_messages_from_agentic_flow( # Create an instance of the InnerAutoGenText2Sql class inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs) - # Add database connection info to injected parameters - query_params = injected_parameters.copy() if injected_parameters else {} - if "Text2Sql__Tsql__ConnectionString" in os.environ: - query_params["database_connection_string"] = os.environ[ - "Text2Sql__Tsql__ConnectionString" - ] - if "Text2Sql__Tsql__Database" in os.environ: - query_params["database_name"] = os.environ[ - "Text2Sql__Tsql__Database" - ] - # Launch tasks for each sub-query inner_solving_generators.append( consume_inner_messages_from_agentic_flow( inner_autogen_text_2_sql.process_user_message( user_message=parallel_message, - injected_parameters=query_params, + injected_parameters=injected_parameters, database_results=filtered_parallel_messages.database_results, ), parallel_message, @@ -294,7 +282,7 @@ async def consume_inner_messages_from_agentic_flow( ), ) - break + return # Final response yield Response( diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py index 818b7d3..8cdbdc7 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py @@ -44,29 +44,6 @@ def __init__(self, **kwargs: dict): self.kwargs = kwargs self.set_mode() - # Store original environment variables - self.original_db_conn = os.environ.get("Text2Sql__Tsql__ConnectionString") - self.original_db_name = os.environ.get("Text2Sql__Tsql__Database") - - def _update_environment(self, injected_parameters: dict = None): - """Update environment variables with injected parameters.""" - if injected_parameters: - if "database_connection_string" in injected_parameters: - os.environ["Text2Sql__Tsql__ConnectionString"] = injected_parameters[ - "database_connection_string" - ] - if "database_name" in injected_parameters: - os.environ["Text2Sql__Tsql__Database"] = injected_parameters[ - "database_name" - ] - - def _restore_environment(self): - """Restore original environment variables.""" - if self.original_db_conn: - os.environ["Text2Sql__Tsql__ConnectionString"] = self.original_db_conn - if self.original_db_name: - os.environ["Text2Sql__Tsql__Database"] = self.original_db_name - def set_mode(self): """Set the mode of the plugin based on the environment variables.""" self.pre_run_query_cache = ( @@ -195,19 +172,12 @@ def process_user_message( """ logging.info("Processing question: %s", user_message) - # Update environment with injected parameters - self._update_environment(injected_parameters) - - try: - agent_input = { - "user_message": user_message, - "injected_parameters": injected_parameters, - } + agent_input = { + "user_message": user_message, + "injected_parameters": injected_parameters, + } - if database_results: - agent_input["database_results"] = database_results + if database_results: + agent_input["database_results"] = database_results - return self.agentic_flow.run_stream(task=json.dumps(agent_input)) - finally: - # Restore original environment - self._restore_environment() + return self.agentic_flow.run_stream(task=json.dumps(agent_input)) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index e31588c..af1e61b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -311,11 +311,22 @@ def handle_node(node): current_limit = parsed_query.args.get("limit") logging.debug("Current Limit: %s", current_limit) - if current_limit is None or current_limit.value > self.row_limit: + # More defensive check to handle different structures + should_add_limit = True + if current_limit is not None: + try: + if hasattr(current_limit, "expression") and hasattr( + current_limit.expression, "value" + ): + if current_limit.expression.value <= self.row_limit: + should_add_limit = False + except AttributeError: + logging.warning("Unexpected limit structure: %s", current_limit) + + if should_add_limit: # Create a new LIMIT expression limit_expr = Limit(expression=Literal.number(self.row_limit)) - - # Attach it to the query by setting it on the SELECT expression + # Attach it to the query parsed_query.set("limit", limit_expr) updated_parsed_queries.append( parsed_query.sql(dialect=self.database_engine.value.lower()) diff --git a/uv.lock b/uv.lock index ba4dbd3..f2f0f31 100644 --- a/uv.lock +++ b/uv.lock @@ -261,19 +261,19 @@ wheels = [ [[package]] name = "autogen-agentchat" -version = "0.4.5" +version = "0.4.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "autogen-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ea/69/c1b511be9a0bc7c65d75266e8c89cd355984612418b581fcafa5c255460e/autogen_agentchat-0.4.5.tar.gz", hash = "sha256:a8d5493b4ec6c45f4d40c33c6d3bb98a6409f7a6428ce4fa9645e51bfe2d7408", size = 59454 } +sdist = { url = "https://files.pythonhosted.org/packages/ec/37/a3083794ee00d3abc51ca24adf009b96f0896def06c0a41f9d8a646fd61f/autogen_agentchat-0.4.7.tar.gz", hash = "sha256:83ab5050e1983e64bddb46a9f10048b315e17bd2a20f083bb5d99844ed36fa7f", size = 62274 } wheels = [ - { url = "https://files.pythonhosted.org/packages/98/51/8182c314dc94cf8d70aa02377c391f3f2c026e88ec33080f230e92ecac4d/autogen_agentchat-0.4.5-py3-none-any.whl", hash = "sha256:76f6fff7ae1ec4eb34f437df6a2a781995c4ed6679e7a6297cfcd89ff9d7791f", size = 64160 }, + { url = "https://files.pythonhosted.org/packages/64/80/6d5b1ecf5e2288f98d8942d57f9442e11eacf3e761517fd9174f370372d8/autogen_agentchat-0.4.7-py3-none-any.whl", hash = "sha256:542fefbfb0cd382c551ced715a93bdf845d8d510b4be85a061550c0a7a08a42f", size = 66229 }, ] [[package]] name = "autogen-core" -version = "0.4.5" +version = "0.4.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jsonref" }, @@ -283,21 +283,21 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/b5/68baeba75e2eb3870e845869c05ae26840b417f0ef7adfb307a3929fd61f/autogen_core-0.4.5.tar.gz", hash = "sha256:dbe09ba585bef18a099bfbcc494385cb383085633eea9e3fd25d0d39393a53be", size = 2314264 } +sdist = { url = "https://files.pythonhosted.org/packages/c7/cf/4b3cdc3fc6c5296344dfacb46fd290562da6fe0f678c3417cacf65fe638a/autogen_core-0.4.7.tar.gz", hash = "sha256:51955b3d7b7b43583373627f25af5cc1214d358350c9fd8a9ded74065ff3ca93", size = 2376802 } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/50/dae1ed34c7e964c04927ac618b342f7cc8fa05dfd90624bb06fa17693898/autogen_core-0.4.5-py3-none-any.whl", hash = "sha256:99b5b0217d3bd4dc317e2ff49ee0340f09a4ac42337e6b227651512d8eb31e9a", size = 78858 }, + { url = "https://files.pythonhosted.org/packages/3a/ad/d2bac3d2adb11388b8147bc682de8dabcb82cf8d774a9f13e43d152e9ab8/autogen_core-0.4.7-py3-none-any.whl", hash = "sha256:31a73b68369c7fa4006e349aedb909c198d887646b0443810d8d5a4694df44bf", size = 81043 }, ] [[package]] name = "autogen-ext" -version = "0.4.5" +version = "0.4.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "autogen-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/63/f3/15fa7c699f94130acaa83cc4a32a51359e013b5f5227b4503905f62fdbc5/autogen_ext-0.4.5.tar.gz", hash = "sha256:ccee093dcd7bcd979d6dc7b1f33d5085747aeaeafee08766e2e9fd15e520c852", size = 143566 } +sdist = { url = "https://files.pythonhosted.org/packages/25/86/28da8045f05d00c095cc592e0a39bc66461aef684bbd57c0bf770588a633/autogen_ext-0.4.7.tar.gz", hash = "sha256:461f992ffebed87075c02d3be7f3b8355262bcf507d012e084eac231ea951b1a", size = 164974 } wheels = [ - { url = "https://files.pythonhosted.org/packages/18/a2/949976cbf6d46b28dd97f0e17604e3e021d657c2a61f21220014fdd0cf83/autogen_ext-0.4.5-py3-none-any.whl", hash = "sha256:1b8441777d9ccce36cd49d835c2677faa891c745f45933e37bcc17c8d4ff187e", size = 144923 }, + { url = "https://files.pythonhosted.org/packages/1c/40/5b7460d5fbf52f18da7c77ff8a6e5186d63d6ad4c8dadef8e8b851ded8a0/autogen_ext-0.4.7-py3-none-any.whl", hash = "sha256:58f23e9e0127de2863ee40fc5bcc64a65596775813dedf44f5765dc17f1d9206", size = 161872 }, ] [package.optional-dependencies] @@ -358,9 +358,9 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiostream", specifier = ">=0.6.4" }, - { name = "autogen-agentchat", specifier = "==0.4.5" }, - { name = "autogen-core", specifier = "==0.4.5" }, - { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.5" }, + { name = "autogen-agentchat", specifier = "==0.4.7" }, + { name = "autogen-core", specifier = "==0.4.7" }, + { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.7" }, { name = "azure-cosmos", specifier = ">=4.9.0" }, { name = "cachetools", specifier = ">=5.5.1" }, { name = "grpcio", specifier = ">=1.68.1" },