Skip to content

Commit f3dd3d6

Browse files
authored
Fix: Avoid concurrent dialect patching in model testing (#4266)
1 parent 4e8228e commit f3dd3d6

File tree

4 files changed

+164
-26
lines changed

4 files changed

+164
-26
lines changed

sqlmesh/core/test/definition.py

Lines changed: 58 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
import datetime
4+
import threading
45
import typing as t
56
import unittest
67
from collections import Counter
7-
from contextlib import AbstractContextManager, nullcontext
8+
from contextlib import nullcontext, contextmanager, AbstractContextManager
89
from itertools import chain
910
from pathlib import Path
1011
from unittest.mock import patch
@@ -46,6 +47,8 @@
4647
class ModelTest(unittest.TestCase):
4748
__test__ = False
4849

50+
CONCURRENT_RENDER_LOCK = threading.Lock()
51+
4952
def __init__(
5053
self,
5154
body: t.Dict[str, t.Any],
@@ -57,6 +60,7 @@ def __init__(
5760
path: Path | None = None,
5861
preserve_fixtures: bool = False,
5962
default_catalog: str | None = None,
63+
concurrency: bool = False,
6064
) -> None:
6165
"""ModelTest encapsulates a unit test for a model.
6266
@@ -79,6 +83,7 @@ def __init__(
7983
self.preserve_fixtures = preserve_fixtures
8084
self.default_catalog = default_catalog
8185
self.dialect = dialect
86+
self.concurrency = concurrency
8287

8388
self._fixture_table_cache: t.Dict[str, exp.Table] = {}
8489
self._normalized_column_name_cache: t.Dict[str, str] = {}
@@ -310,6 +315,7 @@ def create_test(
310315
path: Path | None,
311316
preserve_fixtures: bool = False,
312317
default_catalog: str | None = None,
318+
concurrency: bool = False,
313319
) -> t.Optional[ModelTest]:
314320
"""Create a SqlModelTest or a PythonModelTest.
315321
@@ -353,6 +359,7 @@ def create_test(
353359
path,
354360
preserve_fixtures,
355361
default_catalog,
362+
concurrency,
356363
)
357364

358365
def __str__(self) -> str:
@@ -512,10 +519,34 @@ def _normalize_column_name(self, name: str) -> str:
512519

513520
return normalized_name
514521

515-
def _execute(self, query: exp.Query) -> pd.DataFrame:
522+
@contextmanager
523+
def _concurrent_render_context(self) -> t.Iterator[None]:
524+
"""
525+
Context manager that ensures that the tests are executed safely in a concurrent environment.
526+
This is needed in case `execution_time` is set, as we'd then have to:
527+
- Freeze time through `time_machine` (not thread safe)
528+
- Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
529+
"""
530+
import time_machine
531+
532+
lock_ctx: AbstractContextManager = (
533+
self.CONCURRENT_RENDER_LOCK if self.concurrency else nullcontext()
534+
)
535+
time_ctx: AbstractContextManager = nullcontext()
536+
dialect_patch_ctx: AbstractContextManager = nullcontext()
537+
538+
if self._execution_time:
539+
time_ctx = time_machine.travel(self._execution_time, tick=False)
540+
dialect_patch_ctx = patch.dict(
541+
self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms
542+
)
543+
544+
with lock_ctx, time_ctx, dialect_patch_ctx:
545+
yield
546+
547+
def _execute(self, query: exp.Query | str) -> pd.DataFrame:
516548
"""Executes the given query using the testing engine adapter and returns a DataFrame."""
517-
with patch.dict(self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms):
518-
return self.engine_adapter.fetchdf(query)
549+
return self.engine_adapter.fetchdf(query)
519550

520551
def _create_df(
521552
self,
@@ -570,13 +601,25 @@ def test_ctes(self, ctes: t.Dict[str, exp.Expression], recursive: bool = False)
570601
for alias, cte in ctes.items():
571602
cte_query = cte_query.with_(alias, cte.this, recursive=recursive)
572603

573-
actual = self._execute(cte_query)
604+
with self._concurrent_render_context():
605+
# Similar to the model's query, we render the CTE query under the locked context
606+
# so that the execution (fetchdf) can continue concurrently between the threads
607+
sql = cte_query.sql(
608+
self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql
609+
)
610+
611+
actual = self._execute(sql)
574612
expected = self._create_df(values, columns=cte_query.named_selects, partial=partial)
575613

576614
self.assert_equal(expected, actual, sort=sort, partial=partial)
577615

578616
def runTest(self) -> None:
579-
query = self._render_model_query()
617+
with self._concurrent_render_context():
618+
# Render the model's query and generate the SQL under the locked context so that
619+
# execution (fetchdf) can continue concurrently between the threads
620+
query = self._render_model_query()
621+
sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)
622+
580623
with_clause = query.args.get("with")
581624

582625
if with_clause:
@@ -593,7 +636,7 @@ def runTest(self) -> None:
593636
partial = values.get("partial")
594637
sort = query.args.get("order") is None
595638

596-
actual = self._execute(query)
639+
actual = self._execute(sql)
597640
expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)
598641

599642
self.assert_equal(expected, actual, sort=sort, partial=partial)
@@ -626,6 +669,7 @@ def __init__(
626669
path: Path | None = None,
627670
preserve_fixtures: bool = False,
628671
default_catalog: str | None = None,
672+
concurrency: bool = False,
629673
) -> None:
630674
"""PythonModelTest encapsulates a unit test for a Python model.
631675
@@ -651,6 +695,7 @@ def __init__(
651695
path,
652696
preserve_fixtures,
653697
default_catalog,
698+
concurrency,
654699
)
655700

656701
self.context = TestExecutionContext(
@@ -674,22 +719,13 @@ def runTest(self) -> None:
674719

675720
def _execute_model(self) -> pd.DataFrame:
676721
"""Executes the python model and returns a DataFrame."""
677-
if self._execution_time:
678-
import time_machine
679-
680-
time_ctx: AbstractContextManager = time_machine.travel(self._execution_time, tick=False)
681-
else:
682-
time_ctx = nullcontext()
722+
with self._concurrent_render_context():
723+
variables = self.body.get("vars", {}).copy()
724+
time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
725+
df = next(self.model.render(context=self.context, **time_kwargs, **variables))
683726

684-
with patch.dict(self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms):
685-
with time_ctx:
686-
variables = self.body.get("vars", {}).copy()
687-
time_kwargs = {
688-
key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables
689-
}
690-
df = next(self.model.render(context=self.context, **time_kwargs, **variables))
691-
assert not isinstance(df, exp.Expression)
692-
return df if isinstance(df, pd.DataFrame) else df.toPandas()
727+
assert not isinstance(df, exp.Expression)
728+
return df if isinstance(df, pd.DataFrame) else df.toPandas()
693729

694730

695731
def generate_test(

sqlmesh/core/test/result.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,8 @@ def log_test_report(self, test_duration: float) -> None:
100100
for test_case, failure in failures:
101101
stream.writeln(unittest.TextTestResult.separator1)
102102
stream.writeln(f"FAIL: {test_case}")
103-
stream.writeln(f"{test_case.shortDescription()}")
103+
if test_description := test_case.shortDescription():
104+
stream.writeln(test_description)
104105
stream.writeln(unittest.TextTestResult.separator2)
105106
stream.writeln(failure)
106107

sqlmesh/core/test/runner.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,9 @@ def run_tests(
120120
default_catalog_dialect=default_catalog_dialect,
121121
)
122122

123+
# Ensure workers are not greater than the number of tests
124+
num_workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks)
125+
123126
def _run_single_test(
124127
metadata: ModelTestMetadata, engine_adapter: EngineAdapter
125128
) -> t.Optional[ModelTextTestResult]:
@@ -132,6 +135,7 @@ def _run_single_test(
132135
path=metadata.path,
133136
default_catalog=default_catalog,
134137
preserve_fixtures=preserve_fixtures,
138+
concurrency=num_workers > 1,
135139
)
136140

137141
if not test:
@@ -159,9 +163,6 @@ def _run_single_test(
159163

160164
test_results = []
161165

162-
# Ensure workers are not greater than the number of tests
163-
num_workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks)
164-
165166
start_time = time.perf_counter()
166167
try:
167168
with ThreadPoolExecutor(max_workers=num_workers) as pool:

tests/core/test_test.py

Lines changed: 100 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2370,3 +2370,103 @@ def test_number_of_tests_found(tmp_path: Path) -> None:
23702370
# Case 3: The "new_test.yaml::test_example_full_model2" should amount to a single subtest
23712371
results = context.test(tests=[f"{test_file}::test_example_full_model2"])
23722372
assert len(results.successes) == 1
2373+
2374+
2375+
def test_freeze_time_concurrent(tmp_path: Path) -> None:
2376+
tests_dir = tmp_path / "tests"
2377+
tests_dir.mkdir()
2378+
2379+
macros_dir = tmp_path / "macros"
2380+
macros_dir.mkdir()
2381+
2382+
macro_file = macros_dir / "test_datetime_now.py"
2383+
macro_file.write_text(
2384+
"""
2385+
from sqlglot import exp
2386+
import datetime
2387+
from sqlmesh.core.macros import macro
2388+
2389+
@macro()
2390+
def test_datetime_now(evaluator):
2391+
return exp.cast(exp.Literal.string(datetime.datetime.now(tz=datetime.timezone.utc)), exp.DataType.Type.DATE)
2392+
2393+
@macro()
2394+
def test_sqlglot_expr(evaluator):
2395+
return exp.CurrentDate().sql(evaluator.dialect)
2396+
"""
2397+
)
2398+
2399+
models_dir = tmp_path / "models"
2400+
models_dir.mkdir()
2401+
sql_model1 = models_dir / "sql_model1.sql"
2402+
sql_model1.write_text(
2403+
"""
2404+
MODEL(NAME sql_model1);
2405+
SELECT @test_datetime_now() AS col_exec_ds_time, @test_sqlglot_expr() AS col_current_date;
2406+
"""
2407+
)
2408+
2409+
for model_name in ["sql_model1", "sql_model2", "py_model"]:
2410+
for i in range(5):
2411+
test_2019 = tmp_path / "tests" / f"test_2019_{model_name}_{i}.yaml"
2412+
test_2019.write_text(
2413+
f"""
2414+
test_2019_{model_name}_{i}:
2415+
model: {model_name}
2416+
vars:
2417+
execution_time: '2019-12-01'
2418+
outputs:
2419+
query:
2420+
rows:
2421+
- col_exec_ds_time: '2019-12-01'
2422+
col_current_date: '2019-12-01'
2423+
"""
2424+
)
2425+
2426+
test_2025 = tmp_path / "tests" / f"test_2025_{model_name}_{i}.yaml"
2427+
test_2025.write_text(
2428+
f"""
2429+
test_2025_{model_name}_{i}:
2430+
model: {model_name}
2431+
vars:
2432+
execution_time: '2025-12-01'
2433+
outputs:
2434+
query:
2435+
rows:
2436+
- col_exec_ds_time: '2025-12-01'
2437+
col_current_date: '2025-12-01'
2438+
"""
2439+
)
2440+
2441+
ctx = Context(
2442+
paths=tmp_path,
2443+
config=Config(default_test_connection=DuckDBConnectionConfig(concurrent_tasks=8)),
2444+
)
2445+
2446+
@model(
2447+
"py_model",
2448+
columns={"col_exec_ds_time": "timestamp_ntz", "col_current_date": "timestamp_ntz"},
2449+
)
2450+
def execute(context, start, end, execution_time, **kwargs):
2451+
datetime_now_utc = datetime.datetime.now(tz=datetime.timezone.utc)
2452+
2453+
context.engine_adapter.execute(exp.select("CURRENT_DATE()"))
2454+
current_date = context.engine_adapter.cursor.fetchone()[0]
2455+
2456+
return pd.DataFrame(
2457+
[{"col_exec_ds_time": datetime_now_utc, "col_current_date": current_date}]
2458+
)
2459+
2460+
python_model = model.get_registry()["py_model"].model(module_path=Path("."), path=Path("."))
2461+
ctx.upsert_model(python_model)
2462+
2463+
ctx.upsert_model(
2464+
_create_model(
2465+
meta="MODEL(NAME sql_model2)",
2466+
query="SELECT @execution_ds::timestamp_ntz AS col_exec_ds_time, current_date()::date AS col_current_date",
2467+
default_catalog=ctx.default_catalog,
2468+
)
2469+
)
2470+
2471+
results = ctx.test()
2472+
assert len(results.successes) == 30

0 commit comments

Comments
 (0)