diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index c8cfbda03c..51504ed4f2 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -453,7 +453,7 @@ def snapshot_evaluator(self) -> SnapshotEvaluator: if not self._snapshot_evaluator: self._snapshot_evaluator = SnapshotEvaluator( { - gateway: adapter.with_settings(log_level=logging.INFO) + gateway: adapter.with_settings(execute_log_level=logging.INFO) for gateway, adapter in self.engine_adapters.items() }, ddl_concurrent_tasks=self.concurrent_tasks, @@ -520,7 +520,11 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model: return model - def scheduler(self, environment: t.Optional[str] = None) -> Scheduler: + def scheduler( + self, + environment: t.Optional[str] = None, + snapshot_evaluator: t.Optional[SnapshotEvaluator] = None, + ) -> Scheduler: """Returns the built-in scheduler. Args: @@ -542,9 +546,11 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler: if not snapshots: raise ConfigError("No models were found") - return self.create_scheduler(snapshots) + return self.create_scheduler(snapshots, snapshot_evaluator or self.snapshot_evaluator) - def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler: + def create_scheduler( + self, snapshots: t.Iterable[Snapshot], snapshot_evaluator: SnapshotEvaluator + ) -> Scheduler: """Creates the built-in scheduler. Args: @@ -555,7 +561,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler: """ return Scheduler( snapshots, - self.snapshot_evaluator, + snapshot_evaluator, self.state_sync, default_catalog=self.default_catalog, max_workers=self.concurrent_tasks, @@ -1931,7 +1937,7 @@ def _table_diff( ) return TableDiff( - adapter=adapter.with_settings(logger.getEffectiveLevel()), + adapter=adapter.with_settings(execute_log_level=logger.getEffectiveLevel()), source=source, target=target, on=on, diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 8740177837..1d34ff1401 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -147,19 +147,23 @@ def __init__( self._multithreaded = multithreaded self.correlation_id = correlation_id - def with_settings(self, log_level: int, **kwargs: t.Any) -> EngineAdapter: + def with_settings(self, **kwargs: t.Any) -> EngineAdapter: + extra_kwargs = { + "null_connection": True, + "execute_log_level": kwargs.pop("execute_log_level", self._execute_log_level), + **self._extra_config, + **kwargs, + } + adapter = self.__class__( self._connection_pool, dialect=self.dialect, sql_gen_kwargs=self._sql_gen_kwargs, default_catalog=self._default_catalog, - execute_log_level=log_level, register_comments=self._register_comments, - null_connection=True, multithreaded=self._multithreaded, pretty_sql=self._pretty_sql, - **self._extra_config, - **kwargs, + **extra_kwargs, ) return adapter diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 03f8bdcf71..545a5e5494 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -38,6 +38,7 @@ ) from sqlmesh.utils import to_snake_case from sqlmesh.core.state_sync import StateSync +from sqlmesh.utils import CorrelationId from sqlmesh.utils.concurrency import NodeExecutionFailedError from sqlmesh.utils.errors import PlanError, SQLMeshError from sqlmesh.utils.dag import DAG @@ -71,7 +72,7 @@ def __init__( self, state_sync: StateSync, snapshot_evaluator: SnapshotEvaluator, - create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler], + create_scheduler: t.Callable[[t.Iterable[Snapshot], SnapshotEvaluator], Scheduler], default_catalog: t.Optional[str], console: t.Optional[Console] = None, ): @@ -88,6 +89,9 @@ def evaluate( circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: self._circuit_breaker = circuit_breaker + self.snapshot_evaluator = self.snapshot_evaluator.set_correlation_id( + CorrelationId.from_plan_id(plan.plan_id) + ) self.console.start_plan_evaluation(plan) analytics.collector.on_plan_apply_start( @@ -106,6 +110,7 @@ def evaluate( else: analytics.collector.on_plan_apply_end(plan_id=plan.plan_id) finally: + self.snapshot_evaluator.recycle() self.console.stop_plan_evaluation() def _evaluate_stages( @@ -228,7 +233,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla self.console.log_success("SKIP: No model batches to execute") return - scheduler = self.create_scheduler(stage.all_snapshots.values()) + scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator) errors, _ = scheduler.run_merged_intervals( merged_intervals=stage.snapshot_to_intervals, deployability_index=stage.deployability_index, @@ -249,7 +254,7 @@ def visit_audit_only_run_stage( return # If there are any snapshots to be audited, we'll reuse the scheduler's internals to audit them - scheduler = self.create_scheduler(audit_snapshots) + scheduler = self.create_scheduler(audit_snapshots, self.snapshot_evaluator) completion_status = scheduler.audit( plan.environment, plan.start, diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index eff458dc5d..9a8aa2c7a5 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -61,7 +61,7 @@ SnapshotTableCleanupTask, ) from sqlmesh.core.snapshot.definition import parent_snapshots_by_name -from sqlmesh.utils import random_id +from sqlmesh.utils import random_id, CorrelationId from sqlmesh.utils.concurrency import ( concurrent_apply_to_snapshots, concurrent_apply_to_values, @@ -127,6 +127,7 @@ def __init__( if not selected_gateway else self.adapters[selected_gateway] ) + self.selected_gateway = selected_gateway self.ddl_concurrent_tasks = ddl_concurrent_tasks def evaluate( @@ -1186,6 +1187,16 @@ def _execute_create( ) adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs)) + def set_correlation_id(self, correlation_id: CorrelationId) -> SnapshotEvaluator: + return SnapshotEvaluator( + { + gateway: adapter.with_settings(correlation_id=correlation_id) + for gateway, adapter in self.adapters.items() + }, + self.ddl_concurrent_tasks, + self.selected_gateway, + ) + def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy: klass: t.Type diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 91221d73af..8923c4c75b 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -71,6 +71,7 @@ from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError from sqlmesh.utils.pydantic import validate_string from tests.conftest import DuckDBMetadata, SushiDataValidator +from sqlmesh.utils import CorrelationId from tests.utils.test_helpers import use_terminal_console from tests.utils.test_filesystem import create_temp_file @@ -6815,3 +6816,28 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call # valid_from should be the epoch, valid_to should be NaT assert str(row["valid_from"]) == "1970-01-01 00:00:00" assert pd.isna(row["valid_to"]) + + +def test_plan_evaluator_correlation_id(tmp_path: Path): + def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger): + sqls = [call[0][0] for call in mock_logger.call_args_list] + return any(f"/* {correlation_id} */" in sql for sql in sqls) + + ctx = Context(paths=[tmp_path], config=Config()) + + # Case: Ensure that the correlation id (plan_id) is included in the SQL for each plan + for i in range(2): + create_temp_file( + tmp_path, + Path("models", "test.sql"), + f"MODEL (name test.a, kind FULL); SELECT {i} AS col", + ) + + with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger: + ctx.load() + plan = ctx.plan(auto_apply=True, no_prompts=True) + + correlation_id = CorrelationId.from_plan_id(plan.plan_id) + assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}" + + assert _correlation_id_in_sqls(correlation_id, mock_logger) diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index 1b5c39e2dd..9ea0d64771 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -337,9 +337,9 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) # make with_settings() return the current instance of engine_adapter so we can still spy on _execute mocker.patch.object( - engine_adapter, "with_settings", new_callable=lambda: lambda _: engine_adapter + engine_adapter, "with_settings", new_callable=lambda: lambda **kwargs: engine_adapter ) - assert engine_adapter.with_settings(1) == engine_adapter + assert engine_adapter.with_settings() == engine_adapter spy_execute = mocker.spy(engine_adapter, "_execute") mocker.patch("sqlmesh.core.engine_adapter.base.random_id", return_value="abcdefgh")