Skip to content

Commit bac5566

Browse files
Add automatic row limit (#160)
1 parent 3d92585 commit bac5566

File tree

6 files changed

+79
-24
lines changed

6 files changed

+79
-24
lines changed

text_2_sql/.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Text2Sql__UseQueryCache=<Determines if the Query Cache will be used to speed up
66
Text2Sql__PreRunQueryCache=<Determines if the results from the Query Cache will be pre-run to speed up answer generation. Defaults to True.> # True or False
77
Text2Sql__UseColumnValueStore=<Determines if the Column Value Store will be used for schema selection Defaults to True.> # True or False
88
Text2Sql__GenerateFollowUpSuggestions=<Determines if follow up questions will be generated. Defaults to True.> # True or False
9+
Text2Sql__RowLimit=<Determines the maximum number of rows that will be returned in a query. Defaults to 100.> # Integer
910

1011
# Open AI Connection Details
1112
OpenAI__CompletionDeployment=<openAICompletionDeploymentId. Used for data dictionary creator>

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def set_mode(self):
7979
os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true"
8080
)
8181

82+
# Set the row limit
83+
self.kwargs["row_limit"] = int(os.environ.get("Text2Sql__RowLimit", 100))
84+
8285
def get_all_agents(self):
8386
"""Get all agents for the complete flow."""
8487
# If relationship_paths not provided, use a generic template
@@ -93,31 +96,31 @@ def get_all_agents(self):
9396
- Entity → Attributes (for entity-specific analysis)
9497
"""
9598

96-
self.sql_schema_selection_agent = SqlSchemaSelectionAgent(
99+
sql_schema_selection_agent = SqlSchemaSelectionAgent(
97100
target_engine=self.target_engine,
98101
**self.kwargs,
99102
)
100103

101-
self.sql_query_correction_agent = LLMAgentCreator.create(
104+
sql_query_correction_agent = LLMAgentCreator.create(
102105
"sql_query_correction_agent",
103106
target_engine=self.target_engine,
104107
**self.kwargs,
105108
)
106109

107-
self.disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create(
110+
disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create(
108111
"disambiguation_and_sql_query_generation_agent",
109112
target_engine=self.target_engine,
110113
**self.kwargs,
111114
)
112115
agents = [
113-
self.sql_schema_selection_agent,
114-
self.sql_query_correction_agent,
115-
self.disambiguation_and_sql_query_generation_agent,
116+
sql_schema_selection_agent,
117+
sql_query_correction_agent,
118+
disambiguation_and_sql_query_generation_agent,
116119
]
117120

118121
if self.use_query_cache:
119-
self.query_cache_agent = SqlQueryCacheAgent()
120-
agents.append(self.query_cache_agent)
122+
query_cache_agent = SqlQueryCacheAgent()
123+
agents.append(query_cache_agent)
121124

122125
return agents
123126

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from text_2_sql_core.connectors.factory import ConnectorFactory
77
import asyncio
88
import sqlglot
9-
from sqlglot.expressions import Parameter, Select, Identifier
9+
from sqlglot.expressions import Parameter, Select, Identifier, Literal, Limit
1010
from abc import ABC, abstractmethod
1111
from jinja2 import Template
1212
import json
@@ -30,6 +30,9 @@ def __init__(self):
3030
os.environ.get("Text2Sql__UseAISearch", "True").lower() == "true"
3131
)
3232

33+
# Set the row limit
34+
self.row_limit = int(os.environ.get("Text2Sql__RowLimit", 100))
35+
3336
# Only initialize AI Search connector if enabled
3437
self.ai_search_connector = (
3538
ConnectorFactory.get_ai_search_connector() if self.use_ai_search else None
@@ -195,7 +198,9 @@ async def query_execution_with_limit(
195198
) = await self.query_validation(sql_query)
196199

197200
if validation_result and validation_errors is None:
198-
result = await self.query_execution(cleaned_query, cast_to=None, limit=25)
201+
result = await self.query_execution(
202+
cleaned_query, cast_to=None, limit=self.row_limit
203+
)
199204

200205
return json.dumps(
201206
{
@@ -275,11 +280,13 @@ def handle_node(node):
275280
identifiers.append(node.this)
276281

277282
detected_invalid_identifiers = []
283+
updated_parsed_queries = []
278284

279285
for parsed_query in parsed_queries:
280286
for node in parsed_query.walk():
281287
handle_node(node)
282288

289+
# check for invalid identifiers
283290
for token in expressions + identifiers:
284291
if isinstance(token, Parameter):
285292
identifier = str(token.this.this).upper()
@@ -298,12 +305,32 @@ def handle_node(node):
298305
logging.error(error_message)
299306
return False, None, error_message
300307

308+
# Add a limit clause to the query if it doesn't already have one
309+
for parsed_query in parsed_queries:
310+
# Add a limit clause to the query if it doesn't already have one
311+
current_limit = parsed_query.args.get("limit")
312+
logging.debug("Current Limit: %s", current_limit)
313+
314+
if current_limit is None or current_limit.value > self.row_limit:
315+
# Create a new LIMIT expression
316+
limit_expr = Limit(expression=Literal.number(self.row_limit))
317+
318+
# Attach it to the query by setting it on the SELECT expression
319+
parsed_query.set("limit", limit_expr)
320+
updated_parsed_queries.append(
321+
parsed_query.sql(dialect=self.database_engine.value.lower())
322+
)
323+
else:
324+
updated_parsed_queries.append(
325+
parsed_query.sql(dialect=self.database_engine.value.lower())
326+
)
327+
301328
except sqlglot.errors.ParseError as e:
302329
logging.error("SQL Query is invalid: %s", e.errors)
303330
return False, None, e.errors
304331
else:
305332
logging.info("SQL Query is valid.")
306-
return True, cleaned_query, None
333+
return True, ";".join(updated_parsed_queries), None
307334

308335
async def fetch_sql_queries_with_schemas_from_cache(
309336
self, question: str, injected_parameters: dict = None

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,30 @@
1616

1717

1818
class PayloadSource(StrEnum):
19+
"""Payload source enum."""
20+
1921
USER = "user"
2022
ASSISTANT = "assistant"
2123

2224

2325
class PayloadType(StrEnum):
26+
"""Payload type enum."""
27+
2428
ANSWER_WITH_SOURCES = "answer_with_sources"
2529
DISAMBIGUATION_REQUESTS = "disambiguation_requests"
2630
PROCESSING_UPDATE = "processing_update"
2731
USER_MESSAGE = "user_message"
2832

2933

30-
class InteractionPayloadBase(BaseModel):
34+
class PayloadAndBodyBase(BaseModel):
35+
"""Base class for payloads and bodies."""
36+
3137
model_config = ConfigDict(populate_by_name=True, extra="ignore")
3238

3339

34-
class PayloadBase(InteractionPayloadBase):
40+
class PayloadBase(PayloadAndBodyBase):
41+
"""Base class for payloads."""
42+
3543
message_id: str = Field(
3644
..., default_factory=lambda: str(uuid4()), alias="messageId"
3745
)
@@ -42,12 +50,14 @@ class PayloadBase(InteractionPayloadBase):
4250
payload_type: PayloadType = Field(..., alias="payloadType")
4351
payload_source: PayloadSource = Field(..., alias="payloadSource")
4452

45-
body: InteractionPayloadBase | None = Field(default=None)
53+
body: PayloadAndBodyBase | None = Field(default=None)
54+
4655

56+
class DismabiguationRequestsPayload(PayloadAndBodyBase):
57+
"""Disambiguation requests payload. Handles requests for the end user to response to"""
4758

48-
class DismabiguationRequestsPayload(InteractionPayloadBase):
49-
class Body(InteractionPayloadBase):
50-
class DismabiguationRequest(InteractionPayloadBase):
59+
class Body(PayloadAndBodyBase):
60+
class DismabiguationRequest(PayloadAndBodyBase):
5161
assistant_question: str | None = Field(..., alias="assistantQuestion")
5262
user_choices: list[str] | None = Field(default=None, alias="userChoices")
5363

@@ -65,16 +75,19 @@ class DismabiguationRequest(InteractionPayloadBase):
6575
body: Body | None = Field(default=None)
6676

6777
def __init__(self, **kwargs):
78+
"""Custom init method to pass kwargs to the body."""
6879
super().__init__(**kwargs)
6980

7081
body_kwargs = kwargs.get("body", kwargs)
7182

7283
self.body = self.Body(**body_kwargs)
7384

7485

75-
class AnswerWithSourcesPayload(InteractionPayloadBase):
76-
class Body(InteractionPayloadBase):
77-
class Source(InteractionPayloadBase):
86+
class AnswerWithSourcesPayload(PayloadAndBodyBase):
87+
"""Answer with sources payload. Handles the answer and sources for the answer. The follow up suggestion property is optional and may be used to provide the user with a follow up suggestion."""
88+
89+
class Body(PayloadAndBodyBase):
90+
class Source(PayloadAndBodyBase):
7891
sql_query: str = Field(alias="sqlQuery")
7992
sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows")
8093

@@ -94,15 +107,18 @@ class Source(InteractionPayloadBase):
94107
body: Body | None = Field(default=None)
95108

96109
def __init__(self, **kwargs):
110+
"""Custom init method to pass kwargs to the body."""
97111
super().__init__(**kwargs)
98112

99113
body_kwargs = kwargs.get("body", kwargs)
100114

101115
self.body = self.Body(**body_kwargs)
102116

103117

104-
class ProcessingUpdatePayload(InteractionPayloadBase):
105-
class Body(InteractionPayloadBase):
118+
class ProcessingUpdatePayload(PayloadAndBodyBase):
119+
"""Processing update payload. Handles updates to the user on the processing status."""
120+
121+
class Body(PayloadAndBodyBase):
106122
title: str | None = "Processing..."
107123
message: str | None = "Processing..."
108124

@@ -115,15 +131,18 @@ class Body(InteractionPayloadBase):
115131
body: Body | None = Field(default=None)
116132

117133
def __init__(self, **kwargs):
134+
"""Custom init method to pass kwargs to the body."""
118135
super().__init__(**kwargs)
119136

120137
body_kwargs = kwargs.get("body", kwargs)
121138

122139
self.body = self.Body(**body_kwargs)
123140

124141

125-
class UserMessagePayload(InteractionPayloadBase):
126-
class Body(InteractionPayloadBase):
142+
class UserMessagePayload(PayloadAndBodyBase):
143+
"""User message payload. Handles the user message and injected parameters."""
144+
145+
class Body(PayloadAndBodyBase):
127146
user_message: str = Field(..., alias="userMessage")
128147
injected_parameters: dict = Field(
129148
default_factory=dict, alias="injectedParameters"
@@ -154,6 +173,7 @@ def add_defaults(cls, values):
154173
body: Body | None = Field(default=None)
155174

156175
def __init__(self, **kwargs):
176+
"""Custom init method to pass kwargs to the body."""
157177
super().__init__(**kwargs)
158178

159179
body_kwargs = kwargs.get("body", kwargs)
@@ -162,6 +182,8 @@ def __init__(self, **kwargs):
162182

163183

164184
class InteractionPayload(RootModel):
185+
"""Interaction payload. Handles the root payload for the interaction"""
186+
165187
root: UserMessagePayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field(
166188
..., discriminator="payload_type"
167189
)

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ system_message: |
103103
<sql_query_generation_rules>
104104
<engine_specific_rules>
105105
{{ engine_specific_rules }}
106+
Rows returned will be automatically limited to {{ row_limit }}.
106107
</engine_specific_rules>
107108
108109
Your primary focus is on:

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ system_message: |
1414
1515
<engine_specific_rules>
1616
{{ engine_specific_rules }}
17+
Rows returned will be automatically limited to {{ row_limit }}.
1718
</engine_specific_rules>
1819
1920
<common_conversions>

0 commit comments

Comments
 (0)