
Commit 298e175

feat: Add Spider dataset evaluation framework and SQLite support (#131)
* Add Spider evaluation implementation for text2sql solution
  - Add evaluation utilities for Spider dataset benchmarking
  - Implement SQLite connector for Spider database support
  - Update schema selection and query generation prompts
  - Add evaluation notebook with benchmarking results
  - Update dependencies in pyproject.toml files
* feat: Improved SQL schema selection and SQLite connector for Spider evaluation
* style: Fix trailing whitespace issues
* style: Fix JSON formatting in Jupyter notebook
* style: Apply black formatting to Python files
* style: Apply Ruff fixes
* docs: Update Spider dataset and test suite download instructions
* style: Fix JSON formatting in notebook
* refactor: improve SQL connectors and agents for spider evaluation
  - Update SQL connectors in text_2_sql_core
  - Enhance AutoGen agents for parallel query solving
  - Update schema selection agents
  - Format code with black and fix linting issues
  - Update dependencies
1 parent 408e30f commit 298e175

19 files changed: +1407, -535 lines

.gitignore

Lines changed: 7 additions & 0 deletions
@@ -160,3 +160,10 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# Spider Test Suite
+# These directories contain large test databases and data that can be downloaded separately:
+# Spider test suite evaluation scripts: https://github.com/taoyds/test-suite-sql-eval
+# Spider data: https://drive.google.com/file/d/1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J/view
+/text_2_sql/test-suite-sql-eval/
+/text_2_sql/spider_data/

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
-    "text-2-sql-core",
+    "text-2-sql-core[sqlite]",
 ]

 [dependency-groups]
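
The root package now depends on the `sqlite` extra of `text-2-sql-core`, which the Spider evaluation notebook in this commit exercises. Below is a minimal sketch of how the evaluation points that connector at a single Spider database, assuming the connector reads the same environment variables the notebook sets (`SPIDER_DATA_DIR`, `Text2Sql__DatabaseConnectionString`, `Text2Sql__DatabaseName`); the `concert_singer` database id is purely illustrative.

```python
import os
from pathlib import Path

# Sketch only: variable names and directory layout mirror the evaluation notebook below;
# "concert_singer" is an illustrative Spider database id, not something this diff pins down.
SPIDER_DATA_DIR = Path("text_2_sql/spider_data").absolute()
db_id = "concert_singer"

os.environ["SPIDER_DATA_DIR"] = str(SPIDER_DATA_DIR)  # lets the SQLite connector locate tables.json
os.environ["Text2Sql__DatabaseConnectionString"] = str(
    SPIDER_DATA_DIR / "database" / db_id / f"{db_id}.sqlite"
)
os.environ["Text2Sql__DatabaseName"] = db_id
```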
Lines changed: 320 additions & 0 deletions (new notebook file)
@@ -0,0 +1,320 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate AutoGenText2SQL\n",
    "\n",
    "This notebook evaluates the AutoGenText2Sql class using the Spider test suite evaluation metric.\n",
    "\n",
    "The evaluation uses the official Spider evaluation approach, which requires:\n",
    "\n",
    "1. A gold file with format: `SQL query \\t database_id`\n",
    "2. A predictions file with generated SQL queries\n",
    "3. The Spider databases and schema information\n",
    "\n",
    "### Required Data Downloads\n",
    "\n",
    "Before running this notebook, you need to download and set up two required directories:\n",
    "\n",
    "1. Spider Test Suite Evaluation Scripts:\n",
    "   - Download from: https://github.com/taoyds/test-suite-sql-eval\n",
    "   - Clone this repository into `/text_2_sql/test-suite-sql-eval/` directory:\n",
    "   ```bash\n",
    "   cd text_2_sql\n",
    "   git clone https://github.com/taoyds/test-suite-sql-eval\n",
    "   ```\n",
    "\n",
    "2. Spider Dataset:\n",
    "   - Download from: https://drive.google.com/file/d/1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J/view\n",
    "   - Extract the downloaded file into `/text_2_sql/spider_data/` directory\n",
    "   - The directory should contain:\n",
    "     - `database/` directory with all the SQLite databases\n",
    "     - `tables.json` with schema information\n",
    "     - `dev.json` with development set queries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dependencies\n",
    "\n",
    "To install dependencies for this evaluation:\n",
    "\n",
    "`uv sync --package autogen_text_2_sql`\n",
    "\n",
    "`uv add --editable text_2_sql_core`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import time\n",
    "import json\n",
    "import logging\n",
    "import subprocess\n",
    "import dotenv\n",
    "from pathlib import Path\n",
    "\n",
    "# Get the notebook directory path\n",
    "notebook_dir = Path().absolute()\n",
    "# Add the src directory to the path\n",
    "sys.path.append(str(notebook_dir / \"src\"))\n",
    "\n",
    "from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload\n",
    "from autogen_text_2_sql.evaluation_utils import get_final_sql_query\n",
    "\n",
    "# Configure logging\n",
    "logging.basicConfig(level=logging.DEBUG)\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# Set up paths\n",
    "TEST_SUITE_DIR = Path(\"../test-suite-sql-eval\")\n",
    "SPIDER_DATA_DIR = Path(\"../spider_data\").absolute()\n",
    "DATABASE_DIR = SPIDER_DATA_DIR / \"database\"\n",
    "\n",
    "# Set SPIDER_DATA_DIR in environment so SQLiteSqlConnector can find tables.json\n",
    "os.environ[\"SPIDER_DATA_DIR\"] = str(SPIDER_DATA_DIR)\n",
    "\n",
    "# Load environment variables\n",
    "dotenv.load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the AutoGenText2Sql instance with SQLite-specific rules\n",
    "sqlite_rules = \"\"\"\n",
    "1. Use SQLite syntax\n",
    "2. Do not use Azure SQL specific functions\n",
    "3. Use strftime for date/time operations\n",
    "\"\"\"\n",
    "\n",
    "autogen_text2sql = AutoGenText2Sql(\n",
    "    engine_specific_rules=sqlite_rules,\n",
    "    use_case=\"Evaluating with Spider SQLite databases\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to generate SQL for a given question\n",
    "async def generate_sql(question):\n",
    "    # Capture log output\n",
    "    import io\n",
    "    log_capture = io.StringIO()\n",
    "    handler = logging.StreamHandler(log_capture)\n",
    "    logger.addHandler(handler)\n",
    "\n",
    "    logger.info(f\"Processing question: {question}\")\n",
    "    logger.info(f\"Chat history: None\")\n",
    "\n",
    "    # Track all SQL queries found\n",
    "    all_queries = []\n",
    "    final_query = None\n",
    "\n",
    "    async for message in autogen_text2sql.process_question(QuestionPayload(question=question)):\n",
    "        if message.payload_type == \"answer_with_sources\":\n",
    "            # Extract from results\n",
    "            if hasattr(message.body, 'results'):\n",
    "                for q_results in message.body.results.values():\n",
    "                    for result in q_results:\n",
    "                        if isinstance(result, dict) and 'sql_query' in result:\n",
    "                            sql_query = result['sql_query'].strip()\n",
    "                            if sql_query and sql_query != \"SELECT NULL -- No query found\":\n",
    "                                all_queries.append(sql_query)\n",
    "                                logger.info(f\"Found SQL query in results: {sql_query}\")\n",
    "\n",
    "            # Extract from sources\n",
    "            if hasattr(message.body, 'sources'):\n",
    "                for source in message.body.sources:\n",
    "                    if hasattr(source, 'sql_query'):\n",
    "                        sql_query = source.sql_query.strip()\n",
    "                        if sql_query and sql_query != \"SELECT NULL -- No query found\":\n",
    "                            all_queries.append(sql_query)\n",
    "                            logger.info(f\"Found SQL query in sources: {sql_query}\")\n",
    "\n",
    "    # Get the log text\n",
    "    log_text = log_capture.getvalue()\n",
    "\n",
    "    # Clean up logging\n",
    "    logger.removeHandler(handler)\n",
    "    log_capture.close()\n",
    "\n",
    "    # Log all queries found\n",
    "    if all_queries:\n",
    "        logger.info(f\"All queries found: {all_queries}\")\n",
    "        # Select the most appropriate query - prefer DISTINCT queries for questions about unique values\n",
    "        question_lower = question.lower()\n",
    "        needs_distinct = any(word in question_lower for word in ['different', 'distinct', 'unique', 'all'])\n",
    "\n",
    "        for query in reversed(all_queries):  # Look at queries in reverse order\n",
    "            if needs_distinct and 'DISTINCT' in query.upper():\n",
    "                final_query = query\n",
    "                break\n",
    "        if not final_query:  # If no DISTINCT query found when needed, use the last query\n",
    "            final_query = all_queries[-1]\n",
    "            # Add DISTINCT if needed but not present\n",
    "            if needs_distinct and 'DISTINCT' not in final_query.upper() and final_query.upper().startswith('SELECT '):\n",
    "                final_query = final_query.replace('SELECT ', 'SELECT DISTINCT ', 1)\n",
    "\n",
    "    # Log final query\n",
    "    logger.info(f\"Final SQL query: {final_query or 'SELECT NULL -- No query found'}\")\n",
    "\n",
    "    return final_query or \"SELECT NULL -- No query found\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to read Spider dev set and generate predictions\n",
    "async def generate_predictions(num_samples=None):\n",
    "    # Read Spider dev set\n",
    "    dev_file = SPIDER_DATA_DIR / \"dev.json\"\n",
    "    pred_file = TEST_SUITE_DIR / \"predictions.txt\"\n",
    "    gold_file = TEST_SUITE_DIR / \"gold.txt\"\n",
    "\n",
    "    print(f\"Reading dev queries from {dev_file}\")\n",
    "    with open(dev_file) as f:\n",
    "        dev_data = json.load(f)\n",
    "\n",
    "    # Limit number of samples if specified\n",
    "    if num_samples is not None:\n",
    "        dev_data = dev_data[:num_samples]\n",
    "        print(f\"\\nGenerating predictions for {num_samples} queries...\")\n",
    "    else:\n",
    "        print(f\"\\nGenerating predictions for all {len(dev_data)} queries...\")\n",
    "\n",
    "    predictions = []\n",
    "    gold = []\n",
    "\n",
    "    for idx, item in enumerate(dev_data, 1):\n",
    "        question = item['question']\n",
    "        db_id = item['db_id']\n",
    "        gold_query = item['query']\n",
    "\n",
    "        print(f\"\\nProcessing query {idx}/{len(dev_data)} for database {db_id}\")\n",
    "        print(f\"Question: {question}\")\n",
    "\n",
    "        # Update database connection string for current database\n",
    "        db_path = DATABASE_DIR / db_id / f\"{db_id}.sqlite\"\n",
    "        os.environ[\"Text2Sql__DatabaseConnectionString\"] = str(db_path)\n",
    "        os.environ[\"Text2Sql__DatabaseName\"] = db_id\n",
    "\n",
    "        sql = await generate_sql(question)\n",
    "        predictions.append(f\"{sql}\\t{db_id}\")\n",
    "        gold.append(f\"{gold_query}\\t{db_id}\")\n",
    "        print(f\"Generated SQL: {sql}\")\n",
    "\n",
    "    print(f\"\\nSaving predictions to {pred_file}\")\n",
    "    with open(pred_file, 'w') as f:\n",
    "        f.write('\\n'.join(predictions))\n",
    "\n",
    "    print(f\"Saving gold queries to {gold_file}\")\n",
    "    with open(gold_file, 'w') as f:\n",
    "        f.write('\\n'.join(gold))\n",
    "\n",
    "    return pred_file, gold_file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run evaluation using the test suite evaluation script\n",
    "def run_evaluation():\n",
    "    # Use absolute paths to ensure correct file locations\n",
    "    gold_file = TEST_SUITE_DIR / \"gold.txt\"\n",
    "    pred_file = TEST_SUITE_DIR / \"predictions.txt\"\n",
    "    table_file = SPIDER_DATA_DIR / \"tables.json\"  # Use Spider's schema file\n",
    "\n",
    "    print(f\"Starting evaluation at {time.strftime('%H:%M:%S')}\")\n",
    "    start_time = time.time()\n",
    "\n",
    "    cmd = [\n",
    "        \"python\",\n",
    "        str(TEST_SUITE_DIR / \"evaluation.py\"),\n",
    "        \"--gold\", str(gold_file),\n",
    "        \"--pred\", str(pred_file),\n",
    "        \"--db\", str(DATABASE_DIR),\n",
    "        \"--table\", str(table_file),\n",
    "        \"--etype\", \"all\",\n",
    "        \"--plug_value\",\n",
    "        \"--progress_bar_for_each_datapoint\"  # Show progress for each test input\n",
    "    ]\n",
    "\n",
    "    result = subprocess.run(cmd, capture_output=True, text=True)\n",
    "\n",
    "    end_time = time.time()\n",
    "    duration = end_time - start_time\n",
    "\n",
    "    print(\"\\nEvaluation Results:\")\n",
    "    print(\"==================\")\n",
    "    print(result.stdout)\n",
    "\n",
    "    print(f\"\\nEvaluation completed in {duration:.2f} seconds\")\n",
    "    print(f\"End time: {time.strftime('%H:%M:%S')}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate predictions first - now with optional num_samples parameter\n",
    "await generate_predictions(num_samples=20)  # Generate predictions for just 20 samples (takes about 4 minutes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run evaluation\n",
    "run_evaluation()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
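
The evaluation flow in the notebook above boils down to two tab-separated files plus the cloned evaluation script. Here is a compact sketch of the file format the Spider test suite expects; the queries and database id are placeholders for illustration, not taken from this commit.

```python
# Each line in predictions.txt / gold.txt is "<SQL query>\t<database_id>".
predictions = [
    "SELECT count(*) FROM singer\tconcert_singer",
    "SELECT name FROM stadium ORDER BY capacity DESC LIMIT 1\tconcert_singer",
]
gold = [
    "SELECT count(*) FROM singer\tconcert_singer",
    "SELECT name FROM stadium ORDER BY capacity DESC LIMIT 1\tconcert_singer",
]

with open("predictions.txt", "w") as f:
    f.write("\n".join(predictions))
with open("gold.txt", "w") as f:
    f.write("\n".join(gold))

# run_evaluation() in the notebook then invokes the cloned script, roughly:
#   python test-suite-sql-eval/evaluation.py --gold gold.txt --pred predictions.txt \
#       --db spider_data/database --table spider_data/tables.json --etype all --plug_value
```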

text_2_sql/autogen/pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -11,7 +11,9 @@ dependencies = [
     "autogen-ext[azure,openai]==0.4.0.dev11",
     "grpcio>=1.68.1",
     "pyyaml>=6.0.2",
-    "text_2_sql_core",
+    "text_2_sql_core[snowflake,databricks]",
+    "sqlparse>=0.4.4",
+    "nltk>=3.8.1",
 ]

 [dependency-groups]
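
Besides switching to the `snowflake` and `databricks` extras of `text_2_sql_core`, the autogen package now declares `sqlparse` and `nltk`. The diff does not show how either is used; as a hedged illustration only, `sqlparse` is the kind of dependency you would reach for to tidy generated SQL before it is written to the predictions file.

```python
# Illustrative sketch only: assumes sqlparse is used to normalise generated SQL.
# Nothing in this diff confirms that usage; nltk likewise appears here only as a
# declared dependency.
import sqlparse

raw_sql = "select DISTINCT country from singer where age > 20"
tidy_sql = sqlparse.format(raw_sql, keyword_case="upper", reindent=True)
print(tidy_sql)
# SELECT DISTINCT country
# FROM singer
# WHERE age > 20
```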
