Skip to content

Chore: Reintroduce tagging queries with correlation ID #4895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def snapshot_evaluator(self) -> SnapshotEvaluator:
if not self._snapshot_evaluator:
self._snapshot_evaluator = SnapshotEvaluator(
{
gateway: adapter.with_settings(log_level=logging.INFO)
gateway: adapter.with_settings(execute_log_level=logging.INFO)
for gateway, adapter in self.engine_adapters.items()
},
ddl_concurrent_tasks=self.concurrent_tasks,
Expand Down Expand Up @@ -520,7 +520,11 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:

return model

def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
def scheduler(
self,
environment: t.Optional[str] = None,
snapshot_evaluator: t.Optional[SnapshotEvaluator] = None,
) -> Scheduler:
"""Returns the built-in scheduler.

Args:
Expand All @@ -542,9 +546,11 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
if not snapshots:
raise ConfigError("No models were found")

return self.create_scheduler(snapshots)
return self.create_scheduler(snapshots, snapshot_evaluator or self.snapshot_evaluator)

def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
def create_scheduler(
self, snapshots: t.Iterable[Snapshot], snapshot_evaluator: SnapshotEvaluator
) -> Scheduler:
"""Creates the built-in scheduler.

Args:
Expand All @@ -555,7 +561,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
"""
return Scheduler(
snapshots,
self.snapshot_evaluator,
snapshot_evaluator,
self.state_sync,
default_catalog=self.default_catalog,
max_workers=self.concurrent_tasks,
Expand Down Expand Up @@ -1931,7 +1937,7 @@ def _table_diff(
)

return TableDiff(
adapter=adapter.with_settings(logger.getEffectiveLevel()),
adapter=adapter.with_settings(execute_log_level=logger.getEffectiveLevel()),
source=source,
target=target,
on=on,
Expand Down
14 changes: 9 additions & 5 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,23 @@ def __init__(
self._multithreaded = multithreaded
self.correlation_id = correlation_id

def with_settings(self, log_level: int, **kwargs: t.Any) -> EngineAdapter:
def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
extra_kwargs = {
"null_connection": True,
"execute_log_level": kwargs.pop("execute_log_level", self._execute_log_level),
**self._extra_config,
**kwargs,
}

adapter = self.__class__(
self._connection_pool,
dialect=self.dialect,
sql_gen_kwargs=self._sql_gen_kwargs,
default_catalog=self._default_catalog,
execute_log_level=log_level,
register_comments=self._register_comments,
null_connection=True,
multithreaded=self._multithreaded,
pretty_sql=self._pretty_sql,
**self._extra_config,
**kwargs,
**extra_kwargs,
)

return adapter
Expand Down
11 changes: 8 additions & 3 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from sqlmesh.utils import to_snake_case
from sqlmesh.core.state_sync import StateSync
from sqlmesh.utils import CorrelationId
from sqlmesh.utils.concurrency import NodeExecutionFailedError
from sqlmesh.utils.errors import PlanError, SQLMeshError
from sqlmesh.utils.dag import DAG
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(
self,
state_sync: StateSync,
snapshot_evaluator: SnapshotEvaluator,
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
create_scheduler: t.Callable[[t.Iterable[Snapshot], SnapshotEvaluator], Scheduler],
default_catalog: t.Optional[str],
console: t.Optional[Console] = None,
):
Expand All @@ -88,6 +89,9 @@ def evaluate(
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
self._circuit_breaker = circuit_breaker
self.snapshot_evaluator = self.snapshot_evaluator.set_correlation_id(
CorrelationId.from_plan_id(plan.plan_id)
)

self.console.start_plan_evaluation(plan)
analytics.collector.on_plan_apply_start(
Expand All @@ -106,6 +110,7 @@ def evaluate(
else:
analytics.collector.on_plan_apply_end(plan_id=plan.plan_id)
finally:
self.snapshot_evaluator.recycle()
self.console.stop_plan_evaluation()

def _evaluate_stages(
Expand Down Expand Up @@ -228,7 +233,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
self.console.log_success("SKIP: No model batches to execute")
return

scheduler = self.create_scheduler(stage.all_snapshots.values())
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
errors, _ = scheduler.run_merged_intervals(
merged_intervals=stage.snapshot_to_intervals,
deployability_index=stage.deployability_index,
Expand All @@ -249,7 +254,7 @@ def visit_audit_only_run_stage(
return

# If there are any snapshots to be audited, we'll reuse the scheduler's internals to audit them
scheduler = self.create_scheduler(audit_snapshots)
scheduler = self.create_scheduler(audit_snapshots, self.snapshot_evaluator)
completion_status = scheduler.audit(
plan.environment,
plan.start,
Expand Down
13 changes: 12 additions & 1 deletion sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
SnapshotTableCleanupTask,
)
from sqlmesh.core.snapshot.definition import parent_snapshots_by_name
from sqlmesh.utils import random_id
from sqlmesh.utils import random_id, CorrelationId
from sqlmesh.utils.concurrency import (
concurrent_apply_to_snapshots,
concurrent_apply_to_values,
Expand Down Expand Up @@ -127,6 +127,7 @@ def __init__(
if not selected_gateway
else self.adapters[selected_gateway]
)
self.selected_gateway = selected_gateway
self.ddl_concurrent_tasks = ddl_concurrent_tasks

def evaluate(
Expand Down Expand Up @@ -1186,6 +1187,16 @@ def _execute_create(
)
adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs))

def set_correlation_id(self, correlation_id: CorrelationId) -> SnapshotEvaluator:
    """Return a new ``SnapshotEvaluator`` whose adapters tag queries with *correlation_id*.

    Rather than mutating this evaluator in place, a fresh instance is built:
    each gateway's engine adapter is re-created via
    ``adapter.with_settings(correlation_id=correlation_id)`` so that the id is
    attached to every query the new adapters emit, while the original
    ``ddl_concurrent_tasks`` and ``selected_gateway`` settings are carried over.

    Args:
        correlation_id: Identifier (e.g. derived from a plan id) to associate
            with all queries issued by the returned evaluator's adapters.

    Returns:
        A new ``SnapshotEvaluator`` with correlation-id-tagged adapters.
    """
    return SnapshotEvaluator(
        {
            gateway: adapter.with_settings(correlation_id=correlation_id)
            for gateway, adapter in self.adapters.items()
        },
        self.ddl_concurrent_tasks,
        self.selected_gateway,
    )


def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy:
klass: t.Type
Expand Down
26 changes: 26 additions & 0 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
from sqlmesh.utils.pydantic import validate_string
from tests.conftest import DuckDBMetadata, SushiDataValidator
from sqlmesh.utils import CorrelationId
from tests.utils.test_helpers import use_terminal_console
from tests.utils.test_filesystem import create_temp_file

Expand Down Expand Up @@ -6815,3 +6816,28 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call
# valid_from should be the epoch, valid_to should be NaT
assert str(row["valid_from"]) == "1970-01-01 00:00:00"
assert pd.isna(row["valid_to"])


def test_plan_evaluator_correlation_id(tmp_path: Path):
    """Verify that each applied plan's queries are tagged with its correlation id.

    Two successive plans are applied (the model's SELECT changes between
    iterations to force a new plan), and for each one the logged SQL is
    checked for a ``/* SQLMESH_PLAN: <plan_id> */`` comment.
    """

    def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
        # Collect the first positional argument of every _log_sql call and
        # look for the correlation id rendered as a SQL block comment.
        sqls = [call[0][0] for call in mock_logger.call_args_list]
        return any(f"/* {correlation_id} */" in sql for sql in sqls)

    ctx = Context(paths=[tmp_path], config=Config())

    # Case: Ensure that the correlation id (plan_id) is included in the SQL for each plan
    for i in range(2):
        # Rewrite the model with a different literal each iteration so that
        # the second plan has real changes to apply.
        create_temp_file(
            tmp_path,
            Path("models", "test.sql"),
            f"MODEL (name test.a, kind FULL); SELECT {i} AS col",
        )

        # Intercept SQL logging so the emitted statements can be inspected.
        with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
            ctx.load()
            plan = ctx.plan(auto_apply=True, no_prompts=True)

        correlation_id = CorrelationId.from_plan_id(plan.plan_id)
        assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}"

        assert _correlation_id_in_sqls(correlation_id, mock_logger)
4 changes: 2 additions & 2 deletions tests/core/test_table_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,9 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)

# make with_settings() return the current instance of engine_adapter so we can still spy on _execute
mocker.patch.object(
engine_adapter, "with_settings", new_callable=lambda: lambda _: engine_adapter
engine_adapter, "with_settings", new_callable=lambda: lambda **kwargs: engine_adapter
)
assert engine_adapter.with_settings(1) == engine_adapter
assert engine_adapter.with_settings() == engine_adapter

spy_execute = mocker.spy(engine_adapter, "_execute")
mocker.patch("sqlmesh.core.engine_adapter.base.random_id", return_value="abcdefgh")
Expand Down