1
1
from __future__ import annotations
2
2
3
3
import datetime
4
+ import threading
4
5
import typing as t
5
6
import unittest
6
7
from collections import Counter
7
- from contextlib import AbstractContextManager , nullcontext
8
+ from contextlib import nullcontext , contextmanager , AbstractContextManager
8
9
from itertools import chain
9
10
from pathlib import Path
10
11
from unittest .mock import patch
46
47
class ModelTest (unittest .TestCase ):
47
48
__test__ = False
48
49
50
+ CONCURRENT_RENDER_LOCK = threading .Lock ()
51
+
49
52
def __init__ (
50
53
self ,
51
54
body : t .Dict [str , t .Any ],
@@ -57,6 +60,7 @@ def __init__(
57
60
path : Path | None = None ,
58
61
preserve_fixtures : bool = False ,
59
62
default_catalog : str | None = None ,
63
+ concurrency : bool = False ,
60
64
) -> None :
61
65
"""ModelTest encapsulates a unit test for a model.
62
66
@@ -79,6 +83,7 @@ def __init__(
79
83
self .preserve_fixtures = preserve_fixtures
80
84
self .default_catalog = default_catalog
81
85
self .dialect = dialect
86
+ self .concurrency = concurrency
82
87
83
88
self ._fixture_table_cache : t .Dict [str , exp .Table ] = {}
84
89
self ._normalized_column_name_cache : t .Dict [str , str ] = {}
@@ -310,6 +315,7 @@ def create_test(
310
315
path : Path | None ,
311
316
preserve_fixtures : bool = False ,
312
317
default_catalog : str | None = None ,
318
+ concurrency : bool = False ,
313
319
) -> t .Optional [ModelTest ]:
314
320
"""Create a SqlModelTest or a PythonModelTest.
315
321
@@ -353,6 +359,7 @@ def create_test(
353
359
path ,
354
360
preserve_fixtures ,
355
361
default_catalog ,
362
+ concurrency ,
356
363
)
357
364
358
365
def __str__ (self ) -> str :
@@ -512,10 +519,34 @@ def _normalize_column_name(self, name: str) -> str:
512
519
513
520
return normalized_name
514
521
515
- def _execute (self , query : exp .Query ) -> pd .DataFrame :
522
+ @contextmanager
523
+ def _concurrent_render_context (self ) -> t .Iterator [None ]:
524
+ """
525
+ Context manager that ensures that the tests are executed safely in a concurrent environment.
526
+ This is needed in case `execution_time` is set, as we'd then have to:
527
+ - Freeze time through `time_machine` (not thread safe)
528
+ - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
529
+ """
530
+ import time_machine
531
+
532
+ lock_ctx : AbstractContextManager = (
533
+ self .CONCURRENT_RENDER_LOCK if self .concurrency else nullcontext ()
534
+ )
535
+ time_ctx : AbstractContextManager = nullcontext ()
536
+ dialect_patch_ctx : AbstractContextManager = nullcontext ()
537
+
538
+ if self ._execution_time :
539
+ time_ctx = time_machine .travel (self ._execution_time , tick = False )
540
+ dialect_patch_ctx = patch .dict (
541
+ self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms
542
+ )
543
+
544
+ with lock_ctx , time_ctx , dialect_patch_ctx :
545
+ yield
546
+
547
+ def _execute (self , query : exp .Query | str ) -> pd .DataFrame :
516
548
"""Executes the given query using the testing engine adapter and returns a DataFrame."""
517
- with patch .dict (self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms ):
518
- return self .engine_adapter .fetchdf (query )
549
+ return self .engine_adapter .fetchdf (query )
519
550
520
551
def _create_df (
521
552
self ,
@@ -570,13 +601,25 @@ def test_ctes(self, ctes: t.Dict[str, exp.Expression], recursive: bool = False)
570
601
for alias , cte in ctes .items ():
571
602
cte_query = cte_query .with_ (alias , cte .this , recursive = recursive )
572
603
573
- actual = self ._execute (cte_query )
604
+ with self ._concurrent_render_context ():
605
+ # Similar to the model's query, we render the CTE query under the locked context
606
+ # so that the execution (fetchdf) can continue concurrently between the threads
607
+ sql = cte_query .sql (
608
+ self ._test_adapter_dialect , pretty = self .engine_adapter ._pretty_sql
609
+ )
610
+
611
+ actual = self ._execute (sql )
574
612
expected = self ._create_df (values , columns = cte_query .named_selects , partial = partial )
575
613
576
614
self .assert_equal (expected , actual , sort = sort , partial = partial )
577
615
578
616
def runTest (self ) -> None :
579
- query = self ._render_model_query ()
617
+ with self ._concurrent_render_context ():
618
+ # Render the model's query and generate the SQL under the locked context so that
619
+ # execution (fetchdf) can continue concurrently between the threads
620
+ query = self ._render_model_query ()
621
+ sql = query .sql (self ._test_adapter_dialect , pretty = self .engine_adapter ._pretty_sql )
622
+
580
623
with_clause = query .args .get ("with" )
581
624
582
625
if with_clause :
@@ -593,7 +636,7 @@ def runTest(self) -> None:
593
636
partial = values .get ("partial" )
594
637
sort = query .args .get ("order" ) is None
595
638
596
- actual = self ._execute (query )
639
+ actual = self ._execute (sql )
597
640
expected = self ._create_df (values , columns = self .model .columns_to_types , partial = partial )
598
641
599
642
self .assert_equal (expected , actual , sort = sort , partial = partial )
@@ -626,6 +669,7 @@ def __init__(
626
669
path : Path | None = None ,
627
670
preserve_fixtures : bool = False ,
628
671
default_catalog : str | None = None ,
672
+ concurrency : bool = False ,
629
673
) -> None :
630
674
"""PythonModelTest encapsulates a unit test for a Python model.
631
675
@@ -651,6 +695,7 @@ def __init__(
651
695
path ,
652
696
preserve_fixtures ,
653
697
default_catalog ,
698
+ concurrency ,
654
699
)
655
700
656
701
self .context = TestExecutionContext (
@@ -674,22 +719,13 @@ def runTest(self) -> None:
674
719
675
720
def _execute_model (self ) -> pd .DataFrame :
676
721
"""Executes the python model and returns a DataFrame."""
677
- if self ._execution_time :
678
- import time_machine
679
-
680
- time_ctx : AbstractContextManager = time_machine .travel (self ._execution_time , tick = False )
681
- else :
682
- time_ctx = nullcontext ()
722
+ with self ._concurrent_render_context ():
723
+ variables = self .body .get ("vars" , {}).copy ()
724
+ time_kwargs = {key : variables .pop (key ) for key in TIME_KWARG_KEYS if key in variables }
725
+ df = next (self .model .render (context = self .context , ** time_kwargs , ** variables ))
683
726
684
- with patch .dict (self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms ):
685
- with time_ctx :
686
- variables = self .body .get ("vars" , {}).copy ()
687
- time_kwargs = {
688
- key : variables .pop (key ) for key in TIME_KWARG_KEYS if key in variables
689
- }
690
- df = next (self .model .render (context = self .context , ** time_kwargs , ** variables ))
691
- assert not isinstance (df , exp .Expression )
692
- return df if isinstance (df , pd .DataFrame ) else df .toPandas ()
727
+ assert not isinstance (df , exp .Expression )
728
+ return df if isinstance (df , pd .DataFrame ) else df .toPandas ()
693
729
694
730
695
731
def generate_test (
0 commit comments