Skip to content

Commit fc2ed0d

Browse files
committed
Chore: Reintroduce tagging queries with correlation ID
1 parent f33beb6 commit fc2ed0d

File tree

5 files changed

+71
-22
lines changed

5 files changed

+71
-22
lines changed

sqlmesh/core/context.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -448,14 +448,8 @@ def engine_adapter(self) -> EngineAdapter:
448448
@property
449449
def snapshot_evaluator(self) -> SnapshotEvaluator:
450450
if not self._snapshot_evaluator:
451-
self._snapshot_evaluator = SnapshotEvaluator(
452-
{
453-
gateway: adapter.with_settings(log_level=logging.INFO)
454-
for gateway, adapter in self.engine_adapters.items()
455-
},
456-
ddl_concurrent_tasks=self.concurrent_tasks,
457-
selected_gateway=self.selected_gateway,
458-
)
451+
self._snapshot_evaluator = self._create_snapshot_evaluator(log_level=logging.INFO)
452+
459453
return self._snapshot_evaluator
460454

461455
def execution_context(
@@ -517,7 +511,11 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
517511

518512
return model
519513

520-
def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
514+
def scheduler(
515+
self,
516+
environment: t.Optional[str] = None,
517+
snapshot_evaluator: t.Optional[SnapshotEvaluator] = None,
518+
) -> Scheduler:
521519
"""Returns the built-in scheduler.
522520
523521
Args:
@@ -539,9 +537,11 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
539537
if not snapshots:
540538
raise ConfigError("No models were found")
541539

542-
return self.create_scheduler(snapshots)
540+
return self.create_scheduler(snapshots, snapshot_evaluator or self.snapshot_evaluator)
543541

544-
def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
542+
def create_scheduler(
543+
self, snapshots: t.Iterable[Snapshot], snapshot_evaluator: SnapshotEvaluator
544+
) -> Scheduler:
545545
"""Creates the built-in scheduler.
546546
547547
Args:
@@ -552,7 +552,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
552552
"""
553553
return Scheduler(
554554
snapshots,
555-
self.snapshot_evaluator,
555+
snapshot_evaluator,
556556
self.state_sync,
557557
default_catalog=self.default_catalog,
558558
max_workers=self.concurrent_tasks,
@@ -2960,6 +2960,16 @@ def load_model_tests(
29602960

29612961
return model_tests
29622962

2963+
def _create_snapshot_evaluator(self, **kwargs: t.Any) -> SnapshotEvaluator:
2964+
return SnapshotEvaluator(
2965+
{
2966+
gateway: adapter.with_settings(**kwargs)
2967+
for gateway, adapter in self.engine_adapters.items()
2968+
},
2969+
ddl_concurrent_tasks=self.concurrent_tasks,
2970+
selected_gateway=self.selected_gateway,
2971+
)
2972+
29632973

29642974
class Context(GenericContext[Config]):
29652975
CONFIG_TYPE = Config

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,15 @@ def __init__(
147147
self._multithreaded = multithreaded
148148
self.correlation_id = correlation_id
149149

150-
def with_settings(self, log_level: int, **kwargs: t.Any) -> EngineAdapter:
150+
def with_settings(self, log_level: int = logging.DEBUG, **kwargs: t.Any) -> EngineAdapter:
151151
adapter = self.__class__(
152152
self._connection_pool,
153153
dialect=self.dialect,
154154
sql_gen_kwargs=self._sql_gen_kwargs,
155155
default_catalog=self._default_catalog,
156156
execute_log_level=log_level,
157157
register_comments=self._register_comments,
158-
null_connection=True,
158+
null_connection=self._extra_config.pop("null_connection", True),
159159
multithreaded=self._multithreaded,
160160
pretty_sql=self._pretty_sql,
161161
**self._extra_config,

sqlmesh/core/plan/evaluator.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from sqlmesh.utils import to_snake_case
4040
from sqlmesh.core.state_sync import StateSync
41+
from sqlmesh.utils import CorrelationId
4142
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4243
from sqlmesh.utils.errors import PlanError, SQLMeshError
4344
from sqlmesh.utils.dag import DAG
@@ -71,7 +72,7 @@ def __init__(
7172
self,
7273
state_sync: StateSync,
7374
snapshot_evaluator: SnapshotEvaluator,
74-
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
75+
create_scheduler: t.Callable[[t.Iterable[Snapshot], SnapshotEvaluator], Scheduler],
7576
default_catalog: t.Optional[str],
7677
console: t.Optional[Console] = None,
7778
):
@@ -89,6 +90,7 @@ def evaluate(
8990
) -> None:
9091
self._circuit_breaker = circuit_breaker
9192

93+
self.set_correlation_id(CorrelationId.from_plan_id(plan.plan_id))
9294
self.console.start_plan_evaluation(plan)
9395
analytics.collector.on_plan_apply_start(
9496
plan=plan,
@@ -228,7 +230,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
228230
self.console.log_success("SKIP: No model batches to execute")
229231
return
230232

231-
scheduler = self.create_scheduler(stage.all_snapshots.values())
233+
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
232234
errors, _ = scheduler.run_merged_intervals(
233235
merged_intervals=stage.snapshot_to_intervals,
234236
deployability_index=stage.deployability_index,
@@ -249,7 +251,7 @@ def visit_audit_only_run_stage(
249251
return
250252

251253
# If there are any snapshots to be audited, we'll reuse the scheduler's internals to audit them
252-
scheduler = self.create_scheduler(audit_snapshots)
254+
scheduler = self.create_scheduler(audit_snapshots, self.snapshot_evaluator)
253255
completion_status = scheduler.audit(
254256
plan.environment,
255257
plan.start,
@@ -348,6 +350,13 @@ def visit_finalize_environment_stage(
348350
) -> None:
349351
self.state_sync.finalize(plan.environment)
350352

353+
def set_correlation_id(self, correlation_id: CorrelationId) -> None:
354+
for key, adapter in self.snapshot_evaluator.adapters.items():
355+
if correlation_id != adapter.correlation_id:
356+
self.snapshot_evaluator.adapters[key] = adapter.with_settings(
357+
correlation_id=correlation_id
358+
)
359+
351360
def _promote_snapshots(
352361
self,
353362
plan: EvaluatablePlan,

sqlmesh/core/snapshot/evaluator.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,7 @@ def __init__(
122122
self.adapters = (
123123
adapters if isinstance(adapters, t.Dict) else {selected_gateway or "": adapters}
124124
)
125-
self.adapter = (
126-
next(iter(self.adapters.values()))
127-
if not selected_gateway
128-
else self.adapters[selected_gateway]
129-
)
125+
self.selected_gateway = selected_gateway
130126
self.ddl_concurrent_tasks = ddl_concurrent_tasks
131127

132128
def evaluate(
@@ -603,6 +599,14 @@ def close(self) -> None:
603599
except Exception:
604600
logger.exception("Failed to close Snapshot Evaluator")
605601

602+
@property
603+
def adapter(self) -> EngineAdapter:
604+
return (
605+
next(iter(self.adapters.values()))
606+
if not self.selected_gateway
607+
else self.adapters[self.selected_gateway]
608+
)
609+
606610
def _evaluate_snapshot(
607611
self,
608612
snapshot: Snapshot,

tests/core/test_integration.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
7272
from sqlmesh.utils.pydantic import validate_string
7373
from tests.conftest import DuckDBMetadata, SushiDataValidator
74+
from sqlmesh.utils import CorrelationId
7475
from tests.utils.test_helpers import use_terminal_console
7576
from tests.utils.test_filesystem import create_temp_file
7677

@@ -6815,3 +6816,28 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call
68156816
# valid_from should be the epoch, valid_to should be NaT
68166817
assert str(row["valid_from"]) == "1970-01-01 00:00:00"
68176818
assert pd.isna(row["valid_to"])
6819+
6820+
6821+
def test_plan_evaluator_correlation_id(tmp_path: Path):
6822+
def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
6823+
sqls = [call[0][0] for call in mock_logger.call_args_list]
6824+
return any(f"/* {correlation_id} */" in sql for sql in sqls)
6825+
6826+
ctx = Context(paths=[tmp_path], config=Config())
6827+
6828+
# Case: Ensure that the correlation id (plan_id) is included in the SQL for each plan
6829+
for i in range(2):
6830+
create_temp_file(
6831+
tmp_path,
6832+
Path("models", "test.sql"),
6833+
f"MODEL (name test.a, kind FULL); SELECT {i} AS col",
6834+
)
6835+
6836+
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
6837+
ctx.load()
6838+
plan = ctx.plan(auto_apply=True, no_prompts=True)
6839+
6840+
correlation_id = CorrelationId.from_plan_id(plan.plan_id)
6841+
assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}"
6842+
6843+
assert _correlation_id_in_sqls(correlation_id, mock_logger)

0 commit comments

Comments
 (0)