From 6b6c40a2958e2cb31be65232d8904dbe3aea9134 Mon Sep 17 00:00:00 2001
From: Alexey Egorov <5102843+alexeyegorov@users.noreply.github.com>
Date: Sun, 25 Jan 2026 14:34:35 +0100
Subject: [PATCH 1/3] feat: plan for implementing session mode

---
 PLAN_v2_session_mode.md | 556 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 556 insertions(+)
 create mode 100644 PLAN_v2_session_mode.md

diff --git a/PLAN_v2_session_mode.md b/PLAN_v2_session_mode.md
new file mode 100644
index 000000000..8894b0c29
--- /dev/null
+++ b/PLAN_v2_session_mode.md
@@ -0,0 +1,556 @@
+# Implementation Plan v2: Full Job Cluster Support for SQL + Python Models
+
+## Summary
+
+Enable complete dbt pipelines (SQL + Python models) to execute entirely on Databricks job clusters by:
+1. Adding `method: session` for SQL models using the active SparkSession
+2. Adding a `session` submission method for Python models executing directly in the same session
+
+**Key Benefits:**
+- 🎯 **Primary Goal Achieved**: Run the full dbt pipeline (SQL + Python) on a single job cluster
+- 💰 **70%+ Cost Savings**: Job clusters cost roughly 7¢/DBU less than SQL Warehouses on Azure
+- 📊 **Better Observability**: Entire dbt run tracked in the Databricks Jobs UI (vs. no tracking for SQL models before)
+- ⚡ **Faster Execution**: No job submission overhead for Python models
+- ✅ **Correct Behavior**: Sequential execution respecting DAG dependencies (as intended)
+- 🔧 **Simpler Setup**: No API authentication or permissions management needed
+
+## Background
+
+**Current State:**
+- SQL models require the DBSQL connector (SQL Warehouses or All-Purpose clusters)
+- Python models can use job clusters via the Jobs API (but spawn separate jobs)
+- No way to run a mixed pipeline entirely within a single job cluster session
+
+**Goal:**
+- Run complete dbt pipelines (SQL + Python) on a single job cluster
+- All models execute in the same SparkSession
+- Preserve all dbt-databricks features (Unity Catalog, Streaming Tables, etc.)
+
+## Architecture Overview
+
+```
+┌──────────────────────────────────────────────────────────────┐
+│                    Databricks Job Cluster                    │
+│                                                              │
+│  ┌────────────────────────────────────────────────────────┐  │
+│  │                      SparkSession                      │  │
+│  │                                                        │  │
+│  │  SQL Models ──────► spark.sql()                        │  │
+│  │                                                        │  │
+│  │  Python Models ───► exec() with spark context          │  │
+│  │                                                        │  │
+│  └────────────────────────────────────────────────────────┘  │
+│                                                              │
+│  profiles.yml:  method: session                              │
+│  model config:  submission_method: session                   │
+└──────────────────────────────────────────────────────────────┘
+```
+
+## Implementation Approach
+
+### Phase 1: Add Session Connection Method (SQL Models)
+
+**File: [credentials.py](dbt/adapters/databricks/credentials.py)**
+
+1. Add a `method` field to `DatabricksCredentials`:
+```python
+@dataclass
+class DatabricksCredentials(Credentials):
+    # ... existing fields ...
+    method: Optional[str] = None  # "session" or "dbsql" (default)
+```
+
+2. Update validation in `__post_init__`:
+```python
+# Auto-detect session mode
+if self.method is None:
+    if os.getenv("DBT_DATABRICKS_SESSION_MODE", "").lower() == "true":
+        self.method = "session"
+    elif os.getenv("DATABRICKS_RUNTIME_VERSION") and not self.host:
+        self.method = "session"
+    else:
+        self.method = "dbsql"
+
+# Validate based on method
+if self.method == "session":
+    self._validate_session_mode()
+else:
+    self.validate_creds()
+```
+
+3. Add session validation:
+```python
+def _validate_session_mode(self) -> None:
+    try:
+        from pyspark.sql import SparkSession  # noqa
+    except ImportError:
+        raise DbtRuntimeError("Session mode requires pyspark")
+    if self.schema is None:
+        raise DbtValidationError("Schema is required for session mode")
+```
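+
+To make the resolution order concrete, here is a standalone sketch of the detection rules above (names mirror the plan; this is illustrative, not the final implementation):
+```python
+import os
+from typing import Optional
+
+
+def resolve_connection_method(method: Optional[str], host: Optional[str]) -> str:
+    """Sketch of the plan's auto-detection: explicit setting wins, then environment hints."""
+    if method is not None:
+        return method  # explicit profile setting always wins
+    if os.getenv("DBT_DATABRICKS_SESSION_MODE", "").lower() == "true":
+        return "session"  # explicit opt-in via environment variable
+    if os.getenv("DATABRICKS_RUNTIME_VERSION") and not host:
+        return "session"  # on a Databricks cluster with no host configured
+    return "dbsql"  # default: existing DBSQL connector behavior
+```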
+
+### Phase 2: Create Session Handle Module (SQL Execution)
+
+**New File: [session.py](dbt/adapters/databricks/session.py)**
+
+1. `SessionCursorWrapper` - Adapts DataFrame to cursor interface:
+```python
+class SessionCursorWrapper:
+    def __init__(self, spark: SparkSession):
+        self._spark = spark
+        self._df: Optional[DataFrame] = None
+        self._rows: Optional[list[Row]] = None
+
+    def execute(self, sql: str, bindings=None) -> "SessionCursorWrapper":
+        cleaned_sql = sql.strip().rstrip(";")
+        if bindings:
+            cleaned_sql = cleaned_sql % tuple(bindings)
+        self._df = self._spark.sql(cleaned_sql)
+        return self
+
+    def fetchall(self) -> list[tuple]:
+        if self._rows is None and self._df is not None:
+            self._rows = self._df.collect()
+        return [tuple(row) for row in (self._rows or [])]
+
+    @property
+    def description(self) -> list[tuple]:
+        if self._df is None:
+            return []
+        return [(f.name, f.dataType.simpleString(), None, None, None, None, f.nullable)
+                for f in self._df.schema.fields]
+```
+
+2. `DatabricksSessionHandle` - Wraps SparkSession:
+```python
+class DatabricksSessionHandle:
+    def __init__(self, spark: SparkSession):
+        self._spark = spark
+
+    @staticmethod
+    def create(catalog=None, schema=None, session_properties=None):
+        from pyspark.sql import SparkSession
+        spark = SparkSession.builder.getOrCreate()
+        if catalog:
+            spark.catalog.setCurrentCatalog(catalog)
+        if schema:
+            spark.catalog.setCurrentDatabase(schema)
+        if session_properties:
+            for k, v in session_properties.items():
+                spark.conf.set(k, v)
+        return DatabricksSessionHandle(spark)
+
+    @property
+    def dbr_version(self) -> tuple[int, int]:
+        version_str = self._spark.conf.get(
+            "spark.databricks.clusterUsageTags.sparkVersion"
+        )
+        return SqlUtils.extract_dbr_version(version_str)
+
+    def execute(self, sql, bindings=None) -> SessionCursorWrapper:
+        return SessionCursorWrapper(self._spark).execute(sql, bindings)
+```
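+
+For orientation, a usage sketch of the planned handle (assumes a Databricks runtime where `pyspark` is importable and that `session.py` exists as sketched above):
+```python
+from dbt.adapters.databricks.session import DatabricksSessionHandle
+
+# Runs on a Databricks cluster where a SparkSession already exists.
+handle = DatabricksSessionHandle.create(catalog="main", schema="my_schema")
+cursor = handle.execute("SELECT 1 AS id, 'ok' AS status")
+print(cursor.description)  # [("id", "int", None, None, None, None, False), ...]
+print(cursor.fetchall())   # [(1, "ok")]
+```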
+
+### Phase 3: Update Connection Manager
+
+**File: [connections.py](dbt/adapters/databricks/connections.py)**
+
+1. Modify `open()` to dispatch based on method:
+```python
+@classmethod
+def open(cls, connection: Connection) -> Connection:
+    creds: DatabricksCredentials = connection.credentials
+
+    if creds.method == "session":
+        return cls._open_session(connection, creds)
+
+    # Existing DBSQL logic...
+```
+
+2. Add `_open_session()` method for session-based connections.
+
+### Phase 4: Add Session Python Submission Helper (Python Models)
+
+**File: [python_submissions.py](dbt/adapters/databricks/python_models/python_submissions.py)**
+
+1. Add new `SessionPythonSubmitter`:
+```python
+class SessionPythonSubmitter(PythonSubmitter):
+    """Submitter for Python models using direct execution in current SparkSession.
+
+    NOTE: This does NOT collect data to the driver. The compiled code contains
+    df.write.saveAsTable() which writes directly to storage, just like API-based
+    submission methods.
+    """
+
+    def __init__(self, spark: SparkSession):
+        self._spark = spark
+
+    @override
+    def submit(self, compiled_code: str) -> None:
+        logger.debug("Executing Python model directly in SparkSession.")
+
+        # Create execution context with spark available
+        # The compiled code will:
+        # 1. Execute model() function to get a DataFrame
+        # 2. Call df.write.saveAsTable() to persist to Delta
+        # 3. No collect() - data stays distributed
+        exec_globals = {
+            "spark": self._spark,
+            "dbt": __import__("dbt"),
+        }
+
+        # Execute the compiled Python model code
+        exec(compiled_code, exec_globals)
+```
+
+2. Add session state cleanup utilities:
+```python
+class SessionStateManager:
+    """Manages session state to prevent leakage between Python models."""
+
+    @staticmethod
+    def cleanup_temp_views(spark: SparkSession) -> None:
+        """Drop temporary views created during model execution."""
+        # Get list of temp views
+        temp_views = [row.viewName for row in spark.sql("SHOW VIEWS").collect()
+                      if row.isTemporary]
+        for view in temp_views:
+            spark.catalog.dropTempView(view)
+
+    @staticmethod
+    def get_clean_exec_globals(spark: SparkSession) -> dict:
+        """Return a clean execution context with minimal state."""
+        return {
+            "spark": spark,
+            "dbt": __import__("dbt"),
+            # Add other safe imports as needed
+        }
+```
+
+3. Update `SessionPythonSubmitter` with error handling:
+```python
+class SessionPythonSubmitter(PythonSubmitter):
+    def __init__(self, spark: SparkSession):
+        self._spark = spark
+        self._state_manager = SessionStateManager()
+
+    @override
+    def submit(self, compiled_code: str) -> None:
+        logger.debug("Executing Python model directly in SparkSession.")
+
+        try:
+            # Get clean execution context
+            exec_globals = self._state_manager.get_clean_exec_globals(self._spark)
+
+            # Execute the compiled Python model code
+            exec(compiled_code, exec_globals)
+
+        except Exception as e:
+            logger.error(f"Python model execution failed: {e}")
+            raise DbtRuntimeError(f"Python model execution failed: {e}") from e
+
+        finally:
+            # Clean up temp views to prevent state leakage
+            try:
+                self._state_manager.cleanup_temp_views(self._spark)
+            except Exception as cleanup_error:
+                logger.warning(f"Failed to cleanup temp views: {cleanup_error}")
+```
+
+4. Add new `SessionPythonJobHelper`:
+```python
+class SessionPythonJobHelper(PythonJobHelper):
+    """Helper for Python models executing directly in session mode."""
+
+    tracker = PythonRunTracker()
+
+    def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None:
+        self.credentials = credentials
+        self.parsed_model = ParsedPythonModel(**parsed_model)
+
+        # Get SparkSession directly
+        from pyspark.sql import SparkSession
+        self._spark = SparkSession.builder.getOrCreate()
+
+    def submit(self, compiled_code: str) -> None:
+        submitter = SessionPythonSubmitter(self._spark)
+        submitter.submit(compiled_code)
+```
+
+**File: [impl.py](dbt/adapters/databricks/impl.py)**
+
+5. Register the new submission helper:
+```python
+@property
+def python_submission_helpers(self) -> dict[str, type[PythonJobHelper]]:
+    return {
+        "job_cluster": JobClusterPythonJobHelper,
+        "all_purpose_cluster": AllPurposeClusterPythonJobHelper,
+        "serverless_cluster": ServerlessClusterPythonJobHelper,
+        "workflow_job": WorkflowPythonJobHelper,
+        "session": SessionPythonJobHelper,  # NEW
+    }
+```
+
+6. Update `submit_python_job()` to auto-select session mode:
+```python
+def submit_python_job(self, parsed_model: dict, compiled_code: str) -> AdapterResponse:
+    # Auto-select session submission when in session mode
+    creds = self.config.credentials
+    if creds.method == "session":
+        if parsed_model["config"].get("submission_method") is None:
+            parsed_model["config"]["submission_method"] = "session"
+
+    # ... existing code ...
+    return super().submit_python_job(parsed_model, compiled_code)
+```
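+
+For context, the code that `exec()` receives looks roughly like the following (an illustrative sketch of dbt's compiled Python model output; the exact wrapper dbt generates varies by version, and `_DbtStub` is a hypothetical stand-in for dbt's generated context object):
+```python
+from pyspark.sql import SparkSession
+
+spark = SparkSession.builder.getOrCreate()
+
+
+class _DbtStub:
+    """Hypothetical stand-in for the context object dbt generates at compile time."""
+
+    def config(self, **kwargs: object) -> None:
+        pass  # model-level config is resolved at parse time, not at runtime
+
+
+def model(dbt, session):
+    dbt.config(materialized="table")
+    return session.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])
+
+
+# The generated harness ends with a distributed write - no collect():
+df = model(_DbtStub(), spark)
+df.write.mode("overwrite").saveAsTable("main.my_schema.my_python_model")
+```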
+
+### Phase 5: Testing
+
+**New Files:**
+- `tests/unit/test_session.py` - Unit tests for session handle and cursor
+- `tests/unit/test_session_python.py` - Unit tests for session Python submission
+- `tests/functional/adapter/session/test_session_mixed_pipeline.py` - Integration tests
+
+## Files to Modify
+
+| File | Changes |
+|------|---------|
+| [credentials.py](dbt/adapters/databricks/credentials.py) | Add `method` field, auto-detection, validation |
+| [connections.py](dbt/adapters/databricks/connections.py) | Add session dispatch in `open()`, capabilities caching |
+| [impl.py](dbt/adapters/databricks/impl.py) | Register `SessionPythonJobHelper`, auto-select session mode |
+| [python_submissions.py](dbt/adapters/databricks/python_models/python_submissions.py) | Add `SessionPythonSubmitter`, `SessionPythonJobHelper` |
+| **NEW** [session.py](dbt/adapters/databricks/session.py) | `SessionCursorWrapper`, `DatabricksSessionHandle` |
+| **NEW** [tests/unit/test_session.py](tests/unit/test_session.py) | Unit tests |
+
+## Configuration
+
+**profiles.yml for full session mode:**
+```yaml
+my_project:
+  target: job_cluster
+  outputs:
+    job_cluster:
+      type: databricks
+      method: session
+      catalog: main
+      schema: my_schema
+      # host/http_path/token NOT required
+```
+
+**Model-level override (optional):**
+```sql
+-- For Python models that need a specific submission method
+{{ config(submission_method='session') }}
+```
+
+**Environment variable:**
+```bash
+export DBT_DATABRICKS_SESSION_MODE=true
+```
+
+## Usage Scenarios
+
+### Scenario 1: Run dbt from a Databricks Notebook on a Job Cluster
+```python
+# In a Databricks notebook task within a job
+from dbt.cli.main import dbtRunner
+
+dbt = dbtRunner()
+result = dbt.invoke(["run"])  # Uses session mode automatically
+```
+
+### Scenario 2: Run dbt from a Python Script Task
+```python
+# python_task.py - executed as a Python script task in a Databricks job
+import subprocess
+subprocess.run(["dbt", "run", "--profiles-dir", "/dbfs/path/to/profiles"])
+```
+
+### Scenario 3: Databricks Workflow with a dbt Task
+```yaml
+# Databricks job definition
+tasks:
+  - task_key: run_dbt
+    job_cluster_key: my_cluster
+    python_wheel_task:
+      package_name: my_dbt_project
+      entry_point: run_dbt
+    libraries:
+      - pypi: {package: dbt-databricks}
+```
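+
+### Scenario 4: Pinning the Submission Method in a Python Model
+
+As a counterpart to the SQL override shown in the Configuration section, a Python model can pin its submission method in-model (a sketch; model and column names are illustrative):
+```python
+# models/my_python_model.py
+def model(dbt, session):
+    dbt.config(submission_method="session")
+    upstream = dbt.ref("my_sql_model")
+    return upstream.where("id IS NOT NULL")
+```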
+
+## Verification Plan
+
+1. **Unit tests**: `pytest tests/unit/test_session*.py`
+2. **Local validation**: Test credentials validation for both modes
+3. **Integration test**:
+   ```python
+   # Run in a Databricks notebook on a job cluster
+   from dbt.cli.main import dbtRunner
+
+   dbt = dbtRunner()
+
+   # Test SQL model
+   result = dbt.invoke(["run", "--select", "my_sql_model"])
+   assert result.success
+
+   # Test Python model
+   result = dbt.invoke(["run", "--select", "my_python_model"])
+   assert result.success
+
+   # Test full pipeline
+   result = dbt.invoke(["run"])
+   assert result.success
+   ```
+
+## Session vs API Client Comparison
+
+### Data Collection (Python Models)
+
+**Both approaches write directly to tables without collecting to the driver:**
+
+1. **API Client (Jobs/Command API)**:
+   - Executes: `df.write.saveAsTable("table_name")`
+   - Runs in a separate context/notebook
+   - Data written directly to storage
+
+2. **Session Mode (exec())**:
+   - Executes: `df.write.saveAsTable("table_name")` (same code!)
+   - Runs in the same SparkSession
+   - Data written directly to storage
+
+**Result**: No difference in data collection behavior - both are distributed writes.
+
+### Functional Differences
+
+| Feature | API Client | Session Mode |
+|---------|-----------|--------------|
+| **Execution Location** | Separate job/context | Same SparkSession |
+| **Isolation** | Isolated execution | Shared session state |
+| **Retry Logic** | Built-in via Jobs API | Manual if needed |
+| **Monitoring** | Databricks Jobs UI | Spark UI only |
+| **Permissions** | Requires cluster access permissions | Uses current session |
+| **Async Execution** | Submits and polls | Blocks until complete |
+| **Library Installation** | Via cluster config or job spec | Must be pre-installed |
+| **Resource Limits** | Can specify per-job | Uses session limits |
+
+### Session Mode vs API Client Trade-offs
+
+**IMPORTANT CONTEXT**: These differences only apply to **Python model execution**. SQL models have never had these API client features. Session mode actually **improves** observability by allowing the entire dbt pipeline to run within a Databricks job.
+
+| Aspect | SQL Models (Before) | Session Mode (SQL + Python) | API Client (Python only) |
+|--------|---------------------|----------------------------|--------------------------|
+| **Pipeline Observability** | ❌ No Databricks job tracking | ✅ **Full dbt run tracked in Databricks job** | ✅ Individual Python models tracked |
+| **Cost** | Requires SQL Warehouse/All-Purpose | ✅ **Single job cluster** | Multiple job clusters |
+| **Sequential Execution** | ✅ Follows DAG | ✅ **Follows DAG (intended)** | ✅ Follows DAG |
+| **Per-Model Monitoring** | ❌ Not in Databricks UI | ❌ Not in Databricks UI | ✅ Each Python model visible |
+| **Library Management** | Pre-installed | Pre-installed | ✅ Per-job dynamic install |
+| **Failure Behavior** | Fails dbt run | Fails dbt run | ✅ Per-model retry |
+
+**Key Insight**: The "drawbacks" listed are actually **not regressions** - SQL models never had per-model Databricks job monitoring, dynamic libraries, or isolated retry logic. Session mode brings:
+
+✅ **New Capability**: Run the entire dbt pipeline on a job cluster (previously impossible)
+✅ **Better Observability**: Whole dbt run visible in a Databricks job (vs. no tracking before)
+✅ **Cost Savings**: Single cluster for SQL + Python (vs. SQL Warehouse + separate job clusters)
+✅ **Correct Behavior**: Sequential execution following data lineage (as intended)
+
+**The Only Real Trade-off**: Per-Python-model granularity in the Databricks UI vs. whole-pipeline execution efficiency.
+
+### When to Use Each Approach
+
+**Use Session Mode (Recommended for most use cases):**
+- ✅ Running dbt from a Databricks job (notebooks, tasks, workflows)
+- ✅ Cost optimization is a priority (single cluster for the entire pipeline)
+- ✅ Want Databricks job-level tracking of dbt runs
+- ✅ Sequential execution following the DAG is acceptable (it should be!)
+- ✅ Libraries can be installed on the cluster beforehand
+
+**Use API Client (for specific advanced scenarios):**
+- Need per-Python-model granular monitoring in the Databricks UI
+- Require per-model dynamic library installation
+- Want independent retry logic for each Python model
+- Need to run Python models on different cluster configurations
+- Async submission of Python models (non-DAG execution)
+
+## Limitations
+
+**Actual Limitations:**
+- **No API-based features**: Session mode doesn't have host/token, so API-dependent features (like standalone workflow job creation) won't work
+- **No DBSQL query history**: Queries don't appear in the Databricks SQL query history (but the entire run appears in job logs)
+- **Single cluster compute**: All models use the same cluster resources (this is the goal for cost optimization)
+
+**Not Limitations (these are correct behaviors):**
+- ✅ Python models don't appear as separate jobs in the Databricks UI → **Correct**: They're part of the main dbt job
+- ✅ Models execute sequentially following the DAG → **Correct**: This respects data dependencies
+- ✅ Libraries must be pre-installed → **Normal**: Same as SQL models; install on cluster startup
+
+## Design Decisions
+
+1. **Configuration**: Use `method: session` in profiles.yml (aligns with dbt-spark)
+2. **Python execution**: Direct `exec()` in the current SparkSession (simplest, most compatible)
+3. **Auto-detection**: When `DATABRICKS_RUNTIME_VERSION` is set and no host is configured, default to session mode
+4. **Backward compatibility**: Default behavior is unchanged when `method` is not specified and host/token are present
+
+## Migration Path
+
+For existing users:
+1. No changes required if using DBSQL mode (default)
+2. To use session mode, add `method: session` to the profile
+3. Python models automatically use session submission when the profile is in session mode
+
+## FAQ
+
+### Q: When will Python model results be collected to the driver?
+
+**A: Never.** Both session mode and API client mode execute the same compiled code, which ends with:
+```python
+df.write.saveAsTable("table_name")
+```
+
+This writes the DataFrame directly to Delta/storage in a distributed manner. No `collect()` is called, so data stays distributed across the cluster. The only data transferred is metadata (schema, row counts, etc.).
+
+### Q: What are the drawbacks of not using the API client for Python models?
+
+**A: The "drawbacks" are actually not regressions - they're trade-offs that already existed for SQL models:**
+
+**What you "lose" (but SQL models never had):**
+1. Per-Python-model visibility in the Databricks Jobs UI (vs. per-notebook jobs)
+2. Per-Python-model retry logic (vs. whole dbt run retry)
+3. Dynamic library installation per model
+
+**What you GAIN with session mode:**
+1. ✅ **Run the entire pipeline on a job cluster** (previously impossible)
+2. ✅ **Databricks job tracking of the full dbt run** (better than no tracking)
+3. ✅ **70%+ cost savings** - Single cluster vs. SQL Warehouse + job clusters
+4. ✅ **Faster execution** - No job submission overhead
+5. ✅ **Correct DAG execution** - Sequential following lineage (as intended)
+6. ✅ **Simpler setup** - No API authentication needed
โœ… **Simpler setup** - No API authentication needed + +**Context**: SQL models have always executed sequentially, blocked downstream models, and lacked per-model Databricks UI tracking. Session mode brings Python models to the same (correct) behavior while enabling the entire pipeline to run on cost-effective job clusters. + +**Recommendation**: Use session mode for production pipelines running from Databricks jobs. It's more cost-effective, properly respects the DAG, and provides job-level observability. + +### Q: How do I handle failures in session mode? + +**A: Options:** + +1. **Databricks job-level retry**: Configure the Databricks job that runs dbt to retry on failure +2. **dbt retry logic**: Use `dbt retry` command after failures +3. **Model-level error handling**: Add try/except in Python model code if needed +4. **Checkpointing**: Use incremental models to avoid re-running successful work + +### Q: Can I mix session mode and API client mode? + +**A: Not recommended.** Stick to one approach per dbt run for consistency. However, you could: +- Use session mode for dev/testing +- Use API client mode for production +- Configure via environment-specific profiles + +## Future Considerations + +- Consider submitting as PR to dbt-databricks upstream +- May need periodic sync with dbt-databricks updates if maintained as fork +- Could add support for hybrid mode (session SQL + Jobs API Python) if needed +- Consider adding session-level Python model parallelization using ThreadPoolExecutor From c280488a97cf940871d6625fca9e21d7f86d439f Mon Sep 17 00:00:00 2001 From: Alexey Egorov <5102843+alexeyegorov@users.noreply.github.com> Date: Sun, 25 Jan 2026 18:21:35 +0100 Subject: [PATCH 2/3] feat: Implement session mode support for Databricks connections - Introduced `DatabricksSessionHandle` and `SessionCursorWrapper` to enable SparkSession-based execution. - Updated `DatabricksConnectionManager` to handle session mode connections and capabilities. - Enhanced `DatabricksCredentials` to auto-detect and validate connection methods. - Added session mode handling in Python model submission and execution. - Implemented cleanup for temporary views to prevent state leakage between models. This update allows dbt to run entirely within a single SparkSession on Databricks job clusters, improving execution efficiency and compatibility. 
---
 dbt/adapters/databricks/connections.py        | 100 ++++++-
 dbt/adapters/databricks/credentials.py        |  93 +++++-
 dbt/adapters/databricks/impl.py               |  12 +
 .../python_models/python_submissions.py       | 113 ++++++-
 dbt/adapters/databricks/session.py            | 282 ++++++++++++++++++
 5 files changed, 584 insertions(+), 16 deletions(-)
 create mode 100644 dbt/adapters/databricks/session.py

diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 291ac7677..cbbd171d3 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -49,6 +49,7 @@
 from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils
 from dbt.adapters.databricks.logging import logger
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
+from dbt.adapters.databricks.session import DatabricksSessionHandle, SessionCursorWrapper
 from dbt.adapters.databricks.utils import QueryTagsUtils, is_cluster_http_path, redact_credentials
 
 if TYPE_CHECKING:
@@ -150,6 +151,8 @@ class DatabricksConnectionManager(SparkConnectionManager):
     TYPE: str = "databricks"
     credentials_manager: Optional[DatabricksCredentialManager] = None
     _dbr_capabilities_cache: dict[str, DBRCapabilities] = {}
+    # Cache for session mode capabilities (keyed by "session")
+    _session_capabilities: Optional[DBRCapabilities] = None
 
     def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
         super().__init__(profile, mp_context)
@@ -159,9 +162,19 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
     def api_client(self) -> DatabricksApiClient:
         if self._api_client is None:
             credentials = cast(DatabricksCredentials, self.profile.credentials)
+            if credentials.is_session_mode:
+                raise DbtRuntimeError(
+                    "API client is not available in session mode. "
+                    "Session mode does not support API-based operations."
+                )
             self._api_client = DatabricksApiClient(credentials, 15 * 60)
         return self._api_client
 
+    def is_session_mode(self) -> bool:
+        """Check if the connection is using session mode."""
+        credentials = cast(DatabricksCredentials, self.profile.credentials)
+        return credentials.is_session_mode
+
     def is_cluster(self) -> bool:
         conn = self.get_thread_connection()
         databricks_conn = cast(DatabricksDBTConnection, conn)
@@ -210,14 +223,16 @@ def _cache_dbr_capabilities(cls, creds: DatabricksCredentials, http_path: str)
 
     def cancel_open(self) -> list[str]:
         cancelled = super().cancel_open()
-        logger.info("Cancelling open python jobs")
-        PythonRunTracker.cancel_runs(self.api_client)
+        # Only cancel Python jobs via API if not in session mode
+        if not self.is_session_mode():
+            logger.info("Cancelling open python jobs")
+            PythonRunTracker.cancel_runs(self.api_client)
         return cancelled
 
     def compare_dbr_version(self, major: int, minor: int) -> int:
         version = (major, minor)
-        handle: DatabricksHandle = self.get_thread_connection().handle
+        handle: DatabricksHandle | DatabricksSessionHandle = self.get_thread_connection().handle
         dbr_version = handle.dbr_version
         return (dbr_version > version) - (dbr_version < version)
 
@@ -321,7 +336,7 @@ def add_query(
         fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name)))
 
         with self.exception_handler(sql):
-            cursor: Optional[CursorWrapper] = None
+            cursor: Optional[CursorWrapper | SessionCursorWrapper] = None
             try:
                 log_sql = redact_credentials(sql)
                 if abridge_sql_log:
@@ -337,7 +352,7 @@
 
                 pre = time.time()
 
-                handle: DatabricksHandle = connection.handle
+                handle: DatabricksHandle | DatabricksSessionHandle = connection.handle
                 cursor = handle.execute(sql, bindings)
                 response = self.get_response(cursor)
                 fire_event(
@@ -380,14 +395,18 @@
             cursor.close()
 
     def _execute_with_cursor(
-        self, log_sql: str, f: Callable[[DatabricksHandle], CursorWrapper]
+        self,
+        log_sql: str,
+        f: Callable[
+            [DatabricksHandle | DatabricksSessionHandle], CursorWrapper | SessionCursorWrapper
+        ],
     ) -> "Table":
         connection = self.get_thread_connection()
 
         fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name)))
 
         with self.exception_handler(log_sql):
-            cursor: Optional[CursorWrapper] = None
+            cursor: Optional[CursorWrapper | SessionCursorWrapper] = None
             try:
                 fire_event(
                     SQLQuery(
@@ -399,7 +418,7 @@
 
                 pre = time.time()
 
-                handle: DatabricksHandle = connection.handle
+                handle: DatabricksHandle | DatabricksSessionHandle = connection.handle
                 cursor = f(handle)
                 response = self.get_response(cursor)
 
@@ -464,9 +483,66 @@ def open(cls, connection: Connection) -> Connection:
             return connection
 
         creds: DatabricksCredentials = connection.credentials
+
+        # Dispatch based on connection method
+        if creds.is_session_mode:
+            return cls._open_session(databricks_connection, creds)
+        else:
+            return cls._open_dbsql(databricks_connection, creds)
+
+    @classmethod
+    def _open_session(
+        cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials
+    ) -> Connection:
+        """Open a connection using SparkSession mode."""
+        logger.debug("Opening connection in session mode")
+
+        def connect() -> DatabricksSessionHandle:
+            try:
+                handle = DatabricksSessionHandle.create(
+                    catalog=creds.database,
+                    schema=creds.schema,
+                    session_properties=creds.session_properties,
+                )
+                databricks_connection.session_id = handle.session_id
+
+                # Cache capabilities for session mode
+                cls._cache_session_capabilities(handle)
+                databricks_connection.capabilities = cls._session_capabilities or DBRCapabilities()
+
+                logger.debug(f"Session mode connection opened: {handle}")
+                return handle
+            except Exception as exc:
+                logger.error(ConnectionCreateError(exc))
+                raise DbtDatabaseError(f"Failed to create session connection: {exc}") from exc
+
+        # Session mode doesn't need retry logic as SparkSession is already available
+        databricks_connection.handle = connect()
+        databricks_connection.state = ConnectionState.OPEN
+        return databricks_connection
+
+    @classmethod
+    def _cache_session_capabilities(cls, handle: DatabricksSessionHandle) -> None:
+        """Cache DBR capabilities for session mode."""
+        if cls._session_capabilities is None:
+            dbr_version = handle.dbr_version
+            cls._session_capabilities = DBRCapabilities(
+                dbr_version=dbr_version,
+                is_sql_warehouse=False,  # Session mode is always on a cluster
+            )
+            logger.debug(f"Cached session capabilities: DBR version {dbr_version}")
+
+    @classmethod
+    def _open_dbsql(
+        cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials
+    ) -> Connection:
+        """Open a connection using DBSQL connector."""
         timeout = creds.connect_timeout
 
-        cls.credentials_manager = creds.authenticate()
+        credentials_manager = creds.authenticate()
+        # In DBSQL mode, authenticate() always returns a credentials manager
+        assert credentials_manager is not None, "Credentials manager is required for DBSQL mode"
+        cls.credentials_manager = credentials_manager
 
         # Get merged query tags if we have query header context
         query_header_context = getattr(databricks_connection, "_query_header_context", None)
@@ -475,7 +551,7 @@
         merged_query_tags = QueryConfigUtils.get_merged_query_tags(query_header_context, creds)
 
         conn_args = SqlUtils.prepare_connection_arguments(
-            creds, cls.credentials_manager, databricks_connection.http_path, merged_query_tags
+            creds, credentials_manager, databricks_connection.http_path, merged_query_tags
         )
 
         def connect() -> DatabricksHandle:
@@ -507,7 +583,7 @@ def exponential_backoff(attempt: int) -> int:
         retryable_exceptions = [Error]
 
         return cls.retry_connection(
-            connection,
+            databricks_connection,
             connect=connect,
             logger=logger,
             retryable_exceptions=retryable_exceptions,
@@ -527,7 +603,7 @@ def close(cls, connection: Connection) -> Connection:
 
     @classmethod
     def get_response(cls, cursor: Any) -> AdapterResponse:
-        if isinstance(cursor, CursorWrapper):
+        if isinstance(cursor, (CursorWrapper, SessionCursorWrapper)):
             return cursor.get_response()
         else:
             return AdapterResponse("OK")
diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py
index 7e8af786b..d0a9745d1 100644
--- a/dbt/adapters/databricks/credentials.py
+++ b/dbt/adapters/databricks/credentials.py
@@ -1,12 +1,13 @@
 import itertools
 import json
+import os
 import re
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from typing import Any, Callable, Optional, cast
 
 from dbt.adapters.contracts.connection import Credentials
-from dbt_common.exceptions import DbtConfigError, DbtValidationError
+from dbt_common.exceptions import DbtConfigError, DbtRuntimeError, DbtValidationError
 from mashumaro import DataClassDictMixin
 from requests import PreparedRequest
 from requests.auth import AuthBase
@@ -16,6 +17,14 @@
 from dbt.adapters.databricks.global_state import GlobalState
 from dbt.adapters.databricks.logging import logger
 
+# Connection method constants
+CONNECTION_METHOD_SESSION = "session"
+CONNECTION_METHOD_DBSQL = "dbsql"
+
+# Environment variables for session mode
+DBT_DATABRICKS_SESSION_MODE_ENV = "DBT_DATABRICKS_SESSION_MODE"
+DATABRICKS_RUNTIME_VERSION_ENV = "DATABRICKS_RUNTIME_VERSION"
+
 CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
 DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")
 EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
@@ -48,6 +57,9 @@ class DatabricksCredentials(Credentials):
     connection_parameters: Optional[dict[str, Any]] = None
     auth_type: Optional[str] = None
 
+    # Connection method: "session" for SparkSession mode, "dbsql" for DBSQL connector (default)
+    method: Optional[str] = None
+
     # Named compute resources specified in the profile. Used for
     # creating a connection when a model specifies a compute resource.
     compute: Optional[dict[str, Any]] = None
@@ -102,6 +114,9 @@ def __post_init__(self) -> None:
         else:
             self.database = "hive_metastore"
 
+        # Auto-detect and validate connection method
+        self._init_connection_method()
+
         connection_parameters = self.connection_parameters or {}
         for key in (
             "server_hostname",
@@ -130,9 +145,57 @@
         if "_socket_timeout" not in connection_parameters:
             connection_parameters["_socket_timeout"] = 600
         self.connection_parameters = connection_parameters
-        self._credentials_manager = DatabricksCredentialManager.create_from(self)
+
+        # Only create credentials manager for non-session mode
+        if not self.is_session_mode:
+            self._credentials_manager = DatabricksCredentialManager.create_from(self)
+
+    def _init_connection_method(self) -> None:
+        """Initialize and validate the connection method."""
+        if self.method is None:
+            # Auto-detect session mode
+            if os.getenv(DBT_DATABRICKS_SESSION_MODE_ENV, "").lower() == "true":
+                self.method = CONNECTION_METHOD_SESSION
+            elif os.getenv(DATABRICKS_RUNTIME_VERSION_ENV) and not self.host:
+                # Running on Databricks cluster without host configured
+                self.method = CONNECTION_METHOD_SESSION
+            else:
+                self.method = CONNECTION_METHOD_DBSQL
+
+        # Validate method value
+        if self.method not in (CONNECTION_METHOD_SESSION, CONNECTION_METHOD_DBSQL):
+            raise DbtValidationError(
+                f"Invalid connection method: '{self.method}'. "
+                f"Must be '{CONNECTION_METHOD_SESSION}' or '{CONNECTION_METHOD_DBSQL}'."
+            )
+
+    @property
+    def is_session_mode(self) -> bool:
+        """Check if using session mode (SparkSession) for connections."""
+        return self.method == CONNECTION_METHOD_SESSION
+
+    def _validate_session_mode(self) -> None:
+        """Validate configuration for session mode."""
+        try:
+            from pyspark.sql import SparkSession  # noqa: F401
+        except ImportError:
+            raise DbtRuntimeError(
+                "Session mode requires PySpark. "
+                "Please ensure you are running on a Databricks cluster with PySpark available."
+            )
+
+        if self.schema is None:
+            raise DbtValidationError("Schema is required for session mode.")
 
     def validate_creds(self) -> None:
+        """Validate credentials based on connection method."""
+        if self.is_session_mode:
+            self._validate_session_mode()
+        else:
+            self._validate_dbsql_creds()
+
+    def _validate_dbsql_creds(self) -> None:
+        """Validate credentials for DBSQL connector mode."""
         for key in ["host", "http_path"]:
             if not getattr(self, key):
                 raise DbtConfigError(f"The config '{key}' is required to connect to Databricks")
@@ -197,6 +260,9 @@ def type(self) -> str:
 
     @property
     def unique_field(self) -> str:
+        if self.is_session_mode:
+            # For session mode, use a unique identifier based on catalog and schema
+            return f"session://{self.database}/{self.schema}"
         return cast(str, self.host)
 
     def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str, Any]]:
@@ -209,6 +275,15 @@
             if key in as_dict:
                 yield key, as_dict[key]
 
+    def _connection_keys_session(self) -> tuple[str, ...]:
+        """Connection keys for session mode."""
+        connection_keys = ["method", "schema"]
+        if self.database:
+            connection_keys.insert(1, "catalog")
+        if self.session_properties:
+            connection_keys.append("session_properties")
+        return tuple(connection_keys)
+
     def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]:
         # Assuming `DatabricksCredentials.connection_info(self, *, with_aliases: bool = False)`
         # is called from only:
@@ -218,6 +293,11 @@
         #
         # Thus, if `with_aliases` is `True`, `DatabricksCredentials._connection_keys` should return
         # the internal key names; otherwise it can use aliases to show in `dbt debug`.
+
+        # Session mode has different connection keys
+        if self.is_session_mode:
+            return self._connection_keys_session()
+
         connection_keys = ["host", "http_path", "schema"]
         if with_aliases:
             connection_keys.insert(2, "database")
@@ -239,8 +319,15 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]:
     def cluster_id(self) -> Optional[str]:
         return self.extract_cluster_id(self.http_path)  # type: ignore[arg-type]
 
-    def authenticate(self) -> "DatabricksCredentialManager":
+    def authenticate(self) -> Optional["DatabricksCredentialManager"]:
+        """Authenticate and return credentials manager.
+
+        For session mode, returns None as no external authentication is needed.
+        For DBSQL mode, validates credentials and returns the credentials manager.
+        """
         self.validate_creds()
+        if self.is_session_mode:
+            return None
         assert self._credentials_manager is not None, "Credentials manager is not set."
         return self._credentials_manager
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 2190e1f9c..efbe51f55 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -50,13 +50,16 @@
     DatabricksConnectionManager,
     DatabricksDBTConnection,
 )
+from dbt.adapters.databricks.credentials import DatabricksCredentials
 from dbt.adapters.databricks.dbr_capabilities import DBRCapability
 from dbt.adapters.databricks.global_state import GlobalState
 from dbt.adapters.databricks.handle import SqlUtils
+from dbt.adapters.databricks.logging import logger
 from dbt.adapters.databricks.python_models.python_submissions import (
     AllPurposeClusterPythonJobHelper,
     JobClusterPythonJobHelper,
     ServerlessClusterPythonJobHelper,
+    SessionPythonJobHelper,
     WorkflowPythonJobHelper,
 )
 from dbt.adapters.databricks.relation import (
@@ -801,6 +804,7 @@ def python_submission_helpers(self) -> dict[str, type[PythonJobHelper]]:
             "all_purpose_cluster": AllPurposeClusterPythonJobHelper,
             "serverless_cluster": ServerlessClusterPythonJobHelper,
             "workflow_job": WorkflowPythonJobHelper,
+            "session": SessionPythonJobHelper,
         }
 
     @log_code_execution
@@ -809,6 +813,14 @@ def submit_python_job(self, parsed_model: dict, compiled_code: str) -> AdapterRe
             "user_folder_for_python",
             self.behavior.use_user_folder_for_python.setting,  # type: ignore[attr-defined]
         )
+
+        # Auto-select session submission when in session mode
+        creds = cast(DatabricksCredentials, self.config.credentials)
+        if creds.is_session_mode:
+            if parsed_model["config"].get("submission_method") is None:
+                parsed_model["config"]["submission_method"] = "session"
+                logger.debug("Auto-selected 'session' submission method for Python model")
+
         return super().submit_python_job(parsed_model, compiled_code)
 
     @available
diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py
index 23c2a27c1..a04ef8db8 100644
--- a/dbt/adapters/databricks/python_models/python_submissions.py
+++ b/dbt/adapters/databricks/python_models/python_submissions.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from dbt.adapters.base import PythonJobHelper
 from dbt_common.exceptions import DbtRuntimeError
@@ -12,6 +12,9 @@
 from dbt.adapters.databricks.python_models.python_config import ParsedPythonModel
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
 DEFAULT_TIMEOUT = 60 * 60 * 24
 
@@ -658,3 +661,111 @@ def build_submitter(self) -> PythonSubmitter:
         return PythonNotebookWorkflowSubmitter.create(
             self.api_client, self.tracker, self.parsed_model
         )
+
+
+class SessionStateManager:
+    """Manages session state to prevent leakage between Python models."""
+
+    @staticmethod
+    def cleanup_temp_views(spark: "SparkSession") -> None:
+        """Drop temporary views created during model execution."""
+        try:
+            # Get list of temp views from the current database
+            temp_views = [
+                row.viewName
+                for row in spark.sql("SHOW VIEWS").collect()
+                if hasattr(row, "isTemporary") and row.isTemporary
+            ]
+            for view in temp_views:
+                try:
+                    spark.catalog.dropTempView(view)
+                    logger.debug(f"Dropped temp view: {view}")
+                except Exception as e:
+                    logger.warning(f"Failed to drop temp view {view}: {e}")
+        except Exception as e:
+            logger.debug(f"Could not list temp views for cleanup: {e}")
+
+    @staticmethod
+    def get_clean_exec_globals(spark: "SparkSession") -> dict[str, Any]:
"SparkSession") -> dict[str, Any]: + """Return a clean execution context with minimal state.""" + return { + "spark": spark, + "dbt": __import__("dbt"), + # Standard Python builtins are available by default + } + + +class SessionPythonSubmitter(PythonSubmitter): + """Submitter for Python models using direct execution in current SparkSession. + + NOTE: This does NOT collect data to the driver. The compiled code contains + df.write.saveAsTable() which writes directly to storage, just like API-based + submission methods. + """ + + def __init__(self, spark: "SparkSession"): + self._spark = spark + self._state_manager = SessionStateManager() + + @override + def submit(self, compiled_code: str) -> None: + logger.debug("Executing Python model directly in SparkSession.") + + try: + # Get clean execution context + exec_globals = self._state_manager.get_clean_exec_globals(self._spark) + + # Log a preview of the code being executed + preview_len = min(500, len(compiled_code)) + logger.debug( + f"[Session Python] Executing code preview: {compiled_code[:preview_len]}..." + ) + + # Execute the compiled Python model code + # The compiled code will: + # 1. Execute model() function to get a DataFrame + # 2. Call df.write.saveAsTable() to persist to Delta + # 3. No collect() - data stays distributed + exec(compiled_code, exec_globals) + + logger.debug("[Session Python] Model execution completed successfully") + + except Exception as e: + logger.error(f"Python model execution failed: {e}") + raise DbtRuntimeError(f"Python model execution failed: {e}") from e + + finally: + # Clean up temp views to prevent state leakage + try: + self._state_manager.cleanup_temp_views(self._spark) + except Exception as cleanup_error: + logger.warning(f"Failed to cleanup temp views: {cleanup_error}") + + +class SessionPythonJobHelper(PythonJobHelper): + """Helper for Python models executing directly in session mode. + + This helper executes Python models directly in the current SparkSession + without using the Databricks API. It's designed for running dbt on + job clusters where the SparkSession is already available. + """ + + tracker = PythonRunTracker() + + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: + self.credentials = credentials + self.parsed_model = ParsedPythonModel(**parsed_model) + + # Get SparkSession directly - no API client needed + from pyspark.sql import SparkSession + + self._spark = SparkSession.builder.getOrCreate() + logger.debug( + f"[Session Python] Using SparkSession: {self._spark.sparkContext.applicationId}" + ) + + self._submitter = SessionPythonSubmitter(self._spark) + + def submit(self, compiled_code: str) -> None: + """Submit the compiled Python model for execution.""" + self._submitter.submit(compiled_code) diff --git a/dbt/adapters/databricks/session.py b/dbt/adapters/databricks/session.py new file mode 100644 index 000000000..68df6b046 --- /dev/null +++ b/dbt/adapters/databricks/session.py @@ -0,0 +1,282 @@ +""" +Session mode support for dbt-databricks. + +This module provides SparkSession-based execution for running dbt on Databricks job clusters +without requiring the DBSQL connector. It enables complete dbt pipelines (SQL + Python models) +to execute entirely within a single SparkSession. 
+
+Key components:
+- SessionCursorWrapper: Adapts DataFrame results to the cursor interface expected by dbt
+- DatabricksSessionHandle: Wraps SparkSession to provide the handle interface
+"""
+
+import sys
+from collections.abc import Sequence
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Optional
+
+from dbt.adapters.contracts.connection import AdapterResponse
+from dbt_common.exceptions import DbtRuntimeError
+
+from dbt.adapters.databricks.handle import SqlUtils
+from dbt.adapters.databricks.logging import logger
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame, Row, SparkSession
+    from pyspark.sql.types import StructField
+
+
+class SessionCursorWrapper:
+    """
+    Wraps SparkSession DataFrame results to provide a cursor-like interface.
+
+    This adapter allows dbt to use SparkSession.sql() results in the same way
+    it uses DBSQL cursor results, maintaining compatibility with the existing
+    connection management code.
+    """
+
+    def __init__(self, spark: "SparkSession"):
+        self._spark = spark
+        self._df: Optional["DataFrame"] = None
+        self._rows: Optional[list["Row"]] = None
+        self._query_id: str = "session-query"
+        self.open = True
+
+    def execute(
+        self, sql: str, bindings: Optional[Sequence[Any]] = None
+    ) -> "SessionCursorWrapper":
+        """Execute a SQL statement and store the resulting DataFrame."""
+        cleaned_sql = SqlUtils.clean_sql(sql)
+
+        # Handle bindings by simple string substitution if provided
+        if bindings:
+            translated = SqlUtils.translate_bindings(bindings)
+            if translated:
+                cleaned_sql = cleaned_sql % tuple(translated)
+
+        logger.debug(f"Session mode executing SQL: {cleaned_sql[:200]}...")
+        self._df = self._spark.sql(cleaned_sql)
+        self._rows = None  # Reset cached rows
+        return self
+
+    def fetchall(self) -> Sequence[tuple]:
+        """Fetch all rows from the result set."""
+        if self._rows is None and self._df is not None:
+            self._rows = self._df.collect()
+        return [tuple(row) for row in (self._rows or [])]
+
+    def fetchone(self) -> Optional[tuple]:
+        """Fetch the next row from the result set."""
+        if self._rows is None and self._df is not None:
+            self._rows = self._df.collect()
+        if self._rows:
+            return tuple(self._rows.pop(0))
+        return None
+
+    def fetchmany(self, size: int) -> Sequence[tuple]:
+        """Fetch the next `size` rows from the result set."""
+        if self._rows is None and self._df is not None:
+            self._rows = self._df.collect()
+        if not self._rows:
+            return []
+        result = [tuple(row) for row in self._rows[:size]]
+        self._rows = self._rows[size:]
+        return result
+
+    @property
+    def description(self) -> Optional[list[tuple]]:
+        """Return column descriptions in DB-API format."""
+        if self._df is None:
+            return None
+        return [self._field_to_description(f) for f in self._df.schema.fields]
+
+    @staticmethod
+    def _field_to_description(field: "StructField") -> tuple:
+        """Convert a StructField to a DB-API description tuple."""
+        # DB-API description: (name, type_code, display_size, internal_size,
+        # precision, scale, null_ok)
+        return (
+            field.name,
+            field.dataType.simpleString(),
+            None,
+            None,
+            None,
+            None,
+            field.nullable,
+        )
+
+    def get_response(self) -> AdapterResponse:
+        """Return an adapter response for the executed query."""
+        return AdapterResponse(_message="OK", query_id=self._query_id)
+
+    def cancel(self) -> None:
+        """Cancel the current operation (no-op for session mode)."""
+        logger.debug("SessionCursorWrapper.cancel() called (no-op)")
+        self.open = False
+
+    def close(self) -> None:
+        """Close the cursor."""
logger.debug("SessionCursorWrapper.close() called") + self.open = False + self._df = None + self._rows = None + + def __str__(self) -> str: + return f"SessionCursor(query-id={self._query_id})" + + def __enter__(self) -> "SessionCursorWrapper": + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + self.close() + return exc_val is None + + +class DatabricksSessionHandle: + """ + Handle for a Databricks SparkSession. + + Provides the same interface as DatabricksHandle but uses the active SparkSession + instead of the DBSQL connector. This enables dbt to run on job clusters without + requiring external API connections. + """ + + def __init__(self, spark: "SparkSession"): + self._spark = spark + self.open = True + self._cursor: Optional[SessionCursorWrapper] = None + self._dbr_version: Optional[tuple[int, int]] = None + + @staticmethod + def create( + catalog: Optional[str] = None, + schema: Optional[str] = None, + session_properties: Optional[dict[str, Any]] = None, + ) -> "DatabricksSessionHandle": + """ + Create a DatabricksSessionHandle using the active SparkSession. + + Args: + catalog: Optional catalog to set as current + schema: Optional schema to set as current database + session_properties: Optional session configuration properties + + Returns: + A new DatabricksSessionHandle instance + """ + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + # Set catalog if provided + if catalog: + try: + spark.catalog.setCurrentCatalog(catalog) + logger.debug(f"Set current catalog to: {catalog}") + except Exception as e: + logger.warning(f"Failed to set catalog '{catalog}': {e}") + # Fall back to USE CATALOG for older Spark versions + spark.sql(f"USE CATALOG {catalog}") + + # Set schema/database if provided + if schema: + spark.catalog.setCurrentDatabase(schema) + logger.debug(f"Set current database to: {schema}") + + # Apply session properties + if session_properties: + for key, value in session_properties.items(): + spark.conf.set(key, str(value)) + logger.debug(f"Set session property {key}={value}") + + handle = DatabricksSessionHandle(spark) + logger.debug(f"Created session handle: {handle}") + return handle + + @property + def dbr_version(self) -> tuple[int, int]: + """Get the DBR version of the current cluster.""" + if self._dbr_version is None: + try: + version_str = self._spark.conf.get( + "spark.databricks.clusterUsageTags.sparkVersion", "" + ) + if version_str: + self._dbr_version = SqlUtils.extract_dbr_version(version_str) + else: + # If we can't get the version, assume latest + logger.warning("Could not determine DBR version, assuming latest") + self._dbr_version = (sys.maxsize, sys.maxsize) + except Exception as e: + logger.warning(f"Failed to get DBR version: {e}, assuming latest") + self._dbr_version = (sys.maxsize, sys.maxsize) + return self._dbr_version + + @property + def session_id(self) -> str: + """Get a unique identifier for this session.""" + try: + # Try to get the Spark application ID as a session identifier + return self._spark.sparkContext.applicationId or "session-unknown" + except Exception: + return "session-unknown" + + def execute( + self, sql: str, bindings: Optional[Sequence[Any]] = None + ) -> SessionCursorWrapper: + """Execute a SQL statement and return a cursor wrapper.""" + if not self.open: + raise DbtRuntimeError("Attempting to execute on a closed session handle") + + if self._cursor: + self._cursor.close() + + 
+        self._cursor = SessionCursorWrapper(self._spark)
+        return self._cursor.execute(sql, bindings)
+
+    def list_schemas(
+        self, database: str, schema: Optional[str] = None
+    ) -> SessionCursorWrapper:
+        """List schemas in the given database/catalog."""
+        if schema:
+            sql = f"SHOW SCHEMAS IN {database} LIKE '{schema}'"
+        else:
+            sql = f"SHOW SCHEMAS IN {database}"
+        return self.execute(sql)
+
+    def list_tables(self, database: str, schema: str) -> SessionCursorWrapper:
+        """List tables in the given database and schema."""
+        sql = f"SHOW TABLES IN {database}.{schema}"
+        return self.execute(sql)
+
+    def cancel(self) -> None:
+        """Cancel any in-progress operations."""
+        logger.debug("DatabricksSessionHandle.cancel() called")
+        if self._cursor:
+            self._cursor.cancel()
+        self.open = False
+
+    def close(self) -> None:
+        """Close the session handle."""
+        logger.debug("DatabricksSessionHandle.close() called")
+        if self._cursor:
+            self._cursor.close()
+        self.open = False
+        # Note: We don't stop the SparkSession as it may be shared
+
+    def rollback(self) -> None:
+        """Required for interface compatibility, but not implemented."""
+        logger.debug("NotImplemented: rollback (session mode)")
+
+    def __del__(self) -> None:
+        if self._cursor:
+            self._cursor.close()
+        self.close()
+
+    def __str__(self) -> str:
+        return f"SessionHandle(session-id={self.session_id})"

From 0a9ab92d02f4d01d827eddc1f1be4af500e5415f Mon Sep 17 00:00:00 2001
From: Alexey Egorov <5102843+alexeyegorov@users.noreply.github.com>
Date: Sun, 25 Jan 2026 18:21:54 +0100
Subject: [PATCH 3/3] feat: Add unit tests for session mode functionality

- Introduced comprehensive unit tests for session mode components, including `SessionCursorWrapper`, `DatabricksSessionHandle`, and session mode credentials.
- Enhanced test coverage for session mode auto-detection and validation in `DatabricksCredentials`.
- Implemented tests for session mode Python model submission and execution, ensuring proper handling of temporary views and execution errors.

These additions improve the reliability and robustness of session mode features in the Databricks adapter.
---
 tests/unit/test_connection_manager.py  |   1 +
 tests/unit/test_session.py             | 323 +++++++++++++++++++++++++
 tests/unit/test_session_credentials.py | 236 ++++++++++++++++++
 tests/unit/test_session_python.py      | 189 +++++++++++++++
 4 files changed, 749 insertions(+)
 create mode 100644 tests/unit/test_session.py
 create mode 100644 tests/unit/test_session_credentials.py
 create mode 100644 tests/unit/test_session_python.py

diff --git a/tests/unit/test_connection_manager.py b/tests/unit/test_connection_manager.py
index 8cf9219e0..1a31b1cfb 100644
--- a/tests/unit/test_connection_manager.py
+++ b/tests/unit/test_connection_manager.py
@@ -69,6 +69,7 @@ def test_open_calls_is_cluster_http_path_for_warehouse(
         mock_connection.credentials.connect_retries = 1
         mock_connection.credentials.connect_timeout = 10
         mock_connection.credentials.query_tags = None
+        mock_connection.credentials.is_session_mode = False  # Not session mode
         mock_connection.http_path = "sql/protocolv1/o/abc123def456"
         mock_connection.credentials.authenticate.return_value = Mock()
         mock_connection._query_header_context = None
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
new file mode 100644
index 000000000..f1d947827
--- /dev/null
+++ b/tests/unit/test_session.py
@@ -0,0 +1,323 @@
+"""Unit tests for session mode components."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from dbt.adapters.databricks.session import (
+    DatabricksSessionHandle,
+    SessionCursorWrapper,
+)
+
+
+class TestSessionCursorWrapper:
+    """Tests for SessionCursorWrapper."""
+
+    @pytest.fixture
+    def mock_spark(self):
+        """Create a mock SparkSession."""
+        spark = MagicMock()
+        return spark
+
+    @pytest.fixture
+    def cursor(self, mock_spark):
+        """Create a SessionCursorWrapper with mock SparkSession."""
+        return SessionCursorWrapper(mock_spark)
+
+    def test_execute_cleans_sql(self, cursor, mock_spark):
+        """Test that execute cleans SQL (strips whitespace and trailing semicolon)."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        result = cursor.execute("  SELECT 1;  ")
+
+        mock_spark.sql.assert_called_once_with("SELECT 1")
+        assert result is cursor
+
+    def test_execute_with_bindings(self, cursor, mock_spark):
+        """Test that execute handles bindings via string substitution."""
+        mock_df = MagicMock()
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT %s, %s", (1, "test"))
+
+        mock_spark.sql.assert_called_once_with("SELECT 1, test")
+
+    def test_fetchall_returns_tuples(self, cursor, mock_spark):
+        """Test that fetchall returns list of tuples."""
+        mock_df = MagicMock()
+        mock_row1 = MagicMock()
+        mock_row1.__iter__ = lambda self: iter([1, "a"])
+        mock_row2 = MagicMock()
+        mock_row2.__iter__ = lambda self: iter([2, "b"])
+        mock_df.collect.return_value = [mock_row1, mock_row2]
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT * FROM test")
+        result = cursor.fetchall()
+
+        assert result == [(1, "a"), (2, "b")]
+
+    def test_fetchone_returns_single_tuple(self, cursor, mock_spark):
+        """Test that fetchone returns a single tuple."""
+        mock_df = MagicMock()
+        mock_row = MagicMock()
+        mock_row.__iter__ = lambda self: iter([1, "a"])
+        mock_df.collect.return_value = [mock_row]
+        mock_spark.sql.return_value = mock_df
+
+        cursor.execute("SELECT * FROM test")
+        result = cursor.fetchone()
+
+        assert result == (1, "a")
+
+    def test_fetchone_returns_none_when_empty(self, cursor, mock_spark):
+        """Test that fetchone returns None when no rows."""
+        mock_df = MagicMock()
+        mock_df.collect.return_value = []
= [] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchone() + + assert result is None + + def test_fetchmany_returns_limited_rows(self, cursor, mock_spark): + """Test that fetchmany returns limited number of rows.""" + mock_df = MagicMock() + mock_rows = [MagicMock() for _ in range(5)] + for i, row in enumerate(mock_rows): + row.__iter__ = (lambda i: lambda self: iter([i]))(i) + mock_df.collect.return_value = mock_rows + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchmany(2) + + assert len(result) == 2 + assert result == [(0,), (1,)] + + def test_description_returns_column_info(self, cursor, mock_spark): + """Test that description returns column metadata.""" + mock_df = MagicMock() + mock_field1 = MagicMock() + mock_field1.name = "id" + mock_field1.dataType.simpleString.return_value = "int" + mock_field1.nullable = False + + mock_field2 = MagicMock() + mock_field2.name = "name" + mock_field2.dataType.simpleString.return_value = "string" + mock_field2.nullable = True + + mock_df.schema.fields = [mock_field1, mock_field2] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + desc = cursor.description + + assert len(desc) == 2 + assert desc[0][0] == "id" + assert desc[0][1] == "int" + assert desc[0][6] is False + assert desc[1][0] == "name" + assert desc[1][1] == "string" + assert desc[1][6] is True + + def test_description_returns_none_before_execute(self, cursor): + """Test that description returns None before execute.""" + assert cursor.description is None + + def test_get_response_returns_adapter_response(self, cursor, mock_spark): + """Test that get_response returns an AdapterResponse.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT 1") + response = cursor.get_response() + + assert response._message == "OK" + assert response.query_id == "session-query" + + def test_close_sets_open_to_false(self, cursor): + """Test that close sets open to False.""" + assert cursor.open is True + cursor.close() + assert cursor.open is False + + def test_context_manager(self, mock_spark): + """Test that cursor works as context manager.""" + with SessionCursorWrapper(mock_spark) as cursor: + assert cursor.open is True + assert cursor.open is False + + +class TestDatabricksSessionHandle: + """Tests for DatabricksSessionHandle.""" + + @pytest.fixture + def mock_spark(self): + """Create a mock SparkSession.""" + spark = MagicMock() + spark.sparkContext.applicationId = "app-123" + spark.conf.get.return_value = "14.3.x-scala2.12" + return spark + + @pytest.fixture + def handle(self, mock_spark): + """Create a DatabricksSessionHandle with mock SparkSession.""" + return DatabricksSessionHandle(mock_spark) + + def test_session_id_returns_application_id(self, handle): + """Test that session_id returns the Spark application ID.""" + assert handle.session_id == "app-123" + + def test_dbr_version_extracts_version(self, handle, mock_spark): + """Test that dbr_version extracts version from Spark config.""" + mock_spark.conf.get.return_value = "14.3.x-scala2.12" + + version = handle.dbr_version + + assert version == (14, 3) + + def test_dbr_version_caches_result(self, handle, mock_spark): + """Test that dbr_version caches the result.""" + mock_spark.conf.get.return_value = "14.3.x-scala2.12" + + _ = handle.dbr_version + _ = handle.dbr_version + + # Should only call conf.get once due to caching + assert mock_spark.conf.get.call_count == 1 + + 
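+    # The execute/close tests below assume the handle keeps a reference to its
+    # most recent cursor and closes it before issuing new SQL. A minimal sketch
+    # of that contract, using the plan's naming (the `_cursor` attribute is an
+    # assumption, not necessarily the final implementation):
+    #
+    #     def execute(self, sql, bindings=None) -> SessionCursorWrapper:
+    #         if self._cursor is not None and self._cursor.open:
+    #             self._cursor.close()  # at most one live cursor per handle
+    #         self._cursor = SessionCursorWrapper(self._spark)
+    #         return self._cursor.execute(sql, bindings)
+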
def test_execute_returns_cursor(self, handle, mock_spark): + """Test that execute returns a SessionCursorWrapper.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor = handle.execute("SELECT 1") + + assert isinstance(cursor, SessionCursorWrapper) + + def test_execute_closes_previous_cursor(self, handle, mock_spark): + """Test that execute closes any previous cursor.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor1 = handle.execute("SELECT 1") + assert cursor1.open is True + + cursor2 = handle.execute("SELECT 2") + assert cursor1.open is False + assert cursor2.open is True + + def test_close_sets_open_to_false(self, handle): + """Test that close sets open to False.""" + assert handle.open is True + handle.close() + assert handle.open is False + + def test_list_schemas_executes_show_schemas(self, handle, mock_spark): + """Test that list_schemas executes SHOW SCHEMAS.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_schemas("my_catalog") + + mock_spark.sql.assert_called_with("SHOW SCHEMAS IN my_catalog") + + def test_list_schemas_with_pattern(self, handle, mock_spark): + """Test that list_schemas includes LIKE pattern.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_schemas("my_catalog", "my_schema") + + mock_spark.sql.assert_called_with("SHOW SCHEMAS IN my_catalog LIKE 'my_schema'") + + def test_list_tables_executes_show_tables(self, handle, mock_spark): + """Test that list_tables executes SHOW TABLES.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_tables("my_catalog", "my_schema") + + mock_spark.sql.assert_called_with("SHOW TABLES IN my_catalog.my_schema") + + def test_create_gets_or_creates_spark_session(self): + """Test that create uses getOrCreate.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + handle = DatabricksSessionHandle.create() + + mock_builder.getOrCreate.assert_called_once() + assert handle.session_id == "app-456" + + def test_create_sets_catalog(self): + """Test that create sets the catalog.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + DatabricksSessionHandle.create(catalog="my_catalog") + + mock_spark.catalog.setCurrentCatalog.assert_called_once_with("my_catalog") + + def test_create_sets_schema(self): + """Test that create sets the schema.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + DatabricksSessionHandle.create(schema="my_schema") + + mock_spark.catalog.setCurrentDatabase.assert_called_once_with("my_schema") + + def test_create_sets_session_properties(self): + """Test that create sets session properties.""" + mock_spark = 
MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + DatabricksSessionHandle.create(session_properties={"key1": "value1", "key2": 123}) + + mock_spark.conf.set.assert_any_call("key1", "value1") + mock_spark.conf.set.assert_any_call("key2", "123") diff --git a/tests/unit/test_session_credentials.py b/tests/unit/test_session_credentials.py new file mode 100644 index 000000000..3cc4432a2 --- /dev/null +++ b/tests/unit/test_session_credentials.py @@ -0,0 +1,236 @@ +"""Unit tests for session mode credentials.""" + +import os +from unittest.mock import patch + +import pytest +from dbt_common.exceptions import DbtConfigError, DbtRuntimeError, DbtValidationError + +from dbt.adapters.databricks.credentials import ( + CONNECTION_METHOD_DBSQL, + CONNECTION_METHOD_SESSION, + DBT_DATABRICKS_SESSION_MODE_ENV, + DATABRICKS_RUNTIME_VERSION_ENV, + DatabricksCredentials, +) + + +class TestSessionModeAutoDetection: + """Tests for session mode auto-detection.""" + + def test_default_method_is_dbsql(self): + """Test that default method is dbsql when no env vars set.""" + with patch.dict(os.environ, {}, clear=True): + # Need to provide host/http_path for dbsql mode + creds = DatabricksCredentials( + host="my.databricks.com", + http_path="/sql/1.0/warehouses/abc", + token="token", + schema="test_schema", + ) + assert creds.method == CONNECTION_METHOD_DBSQL + assert creds.is_session_mode is False + + def test_explicit_session_method(self): + """Test that explicit method=session is respected.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + assert creds.method == CONNECTION_METHOD_SESSION + assert creds.is_session_mode is True + + def test_explicit_dbsql_method(self): + """Test that explicit method=dbsql is respected.""" + creds = DatabricksCredentials( + method="dbsql", + host="my.databricks.com", + http_path="/sql/1.0/warehouses/abc", + token="token", + schema="test_schema", + ) + assert creds.method == CONNECTION_METHOD_DBSQL + assert creds.is_session_mode is False + + def test_env_var_enables_session_mode(self): + """Test that DBT_DATABRICKS_SESSION_MODE=true enables session mode.""" + with patch.dict(os.environ, {DBT_DATABRICKS_SESSION_MODE_ENV: "true"}): + creds = DatabricksCredentials(schema="test_schema") + assert creds.method == CONNECTION_METHOD_SESSION + assert creds.is_session_mode is True + + def test_env_var_case_insensitive(self): + """Test that DBT_DATABRICKS_SESSION_MODE is case insensitive.""" + with patch.dict(os.environ, {DBT_DATABRICKS_SESSION_MODE_ENV: "TRUE"}): + creds = DatabricksCredentials(schema="test_schema") + assert creds.is_session_mode is True + + def test_databricks_runtime_env_without_host_enables_session(self): + """Test that DATABRICKS_RUNTIME_VERSION without host enables session mode.""" + with patch.dict(os.environ, {DATABRICKS_RUNTIME_VERSION_ENV: "14.3.x-scala2.12"}): + creds = DatabricksCredentials(schema="test_schema") + assert creds.method == CONNECTION_METHOD_SESSION + assert creds.is_session_mode is True + + def test_databricks_runtime_env_with_host_uses_dbsql(self): + """Test that DATABRICKS_RUNTIME_VERSION with host uses dbsql mode.""" + with patch.dict(os.environ, {DATABRICKS_RUNTIME_VERSION_ENV: "14.3.x-scala2.12"}): + creds = DatabricksCredentials( + 
host="my.databricks.com", + http_path="/sql/1.0/warehouses/abc", + token="token", + schema="test_schema", + ) + assert creds.method == CONNECTION_METHOD_DBSQL + assert creds.is_session_mode is False + + def test_invalid_method_raises_error(self): + """Test that invalid method raises DbtValidationError.""" + with pytest.raises(DbtValidationError) as exc_info: + DatabricksCredentials( + method="invalid", + schema="test_schema", + ) + assert "Invalid connection method" in str(exc_info.value) + + +class TestSessionModeValidation: + """Tests for session mode validation.""" + + def test_session_mode_requires_schema(self): + """Test that session mode requires schema.""" + with patch( + "dbt.adapters.databricks.credentials.DatabricksCredentials._validate_session_mode" + ) as mock_validate: + mock_validate.side_effect = DbtValidationError("Schema is required for session mode.") + with pytest.raises(DbtValidationError) as exc_info: + creds = DatabricksCredentials(method="session") + creds.validate_creds() + assert "Schema is required" in str(exc_info.value) + + def test_session_mode_does_not_require_host(self): + """Test that session mode does not require host.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + # Should not raise - host is not required for session mode + assert creds.host is None + + def test_session_mode_does_not_require_http_path(self): + """Test that session mode does not require http_path.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + # Should not raise - http_path is not required for session mode + assert creds.http_path is None + + def test_session_mode_does_not_require_token(self): + """Test that session mode does not require token.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + # Should not raise - token is not required for session mode + assert creds.token is None + + +class TestSessionModeCredentialsManager: + """Tests for credentials manager in session mode.""" + + def test_session_mode_does_not_create_credentials_manager(self): + """Test that session mode does not create credentials manager.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + assert creds._credentials_manager is None + + def test_session_mode_authenticate_returns_none(self): + """Test that authenticate returns None in session mode.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + # Mock _validate_session_mode to avoid pyspark import + with patch.object(creds, "_validate_session_mode"): + result = creds.authenticate() + assert result is None + + +class TestSessionModeConnectionKeys: + """Tests for connection keys in session mode.""" + + def test_session_mode_connection_keys(self): + """Test that session mode has correct connection keys.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + database="main", + ) + keys = creds._connection_keys() + assert "method" in keys + assert "schema" in keys + assert "catalog" in keys + assert "host" not in keys + assert "http_path" not in keys + + def test_session_mode_unique_field(self): + """Test that session mode unique_field is based on catalog/schema.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + database="main", + ) + assert creds.unique_field == "session://main/test_schema" + + +class TestDbsqlModeValidation: + """Tests for DBSQL mode validation (existing behavior).""" + + def test_dbsql_mode_requires_host(self): + 
"""Test that DBSQL mode requires host.""" + # Create credentials with session mode first (no SDK auth), then switch to dbsql + with patch.dict(os.environ, {}, clear=True): + creds = DatabricksCredentials( + method="session", # Start with session to avoid SDK auth + http_path="/sql/1.0/warehouses/abc", + schema="test_schema", + ) + # Switch to dbsql mode for validation test + creds.method = "dbsql" + with pytest.raises(DbtConfigError) as exc_info: + creds.validate_creds() + assert "host" in str(exc_info.value) + + def test_dbsql_mode_requires_http_path(self): + """Test that DBSQL mode requires http_path.""" + # Create credentials with session mode first (no SDK auth), then switch to dbsql + with patch.dict(os.environ, {}, clear=True): + creds = DatabricksCredentials( + method="session", # Start with session to avoid SDK auth + host="my.databricks.com", + schema="test_schema", + ) + # Switch to dbsql mode for validation test + creds.method = "dbsql" + with pytest.raises(DbtConfigError) as exc_info: + creds.validate_creds() + assert "http_path" in str(exc_info.value) + + def test_dbsql_mode_requires_token_or_oauth(self): + """Test that DBSQL mode requires token or oauth.""" + # Create credentials with session mode first (no SDK auth), then switch to dbsql + with patch.dict(os.environ, {}, clear=True): + creds = DatabricksCredentials( + method="session", # Start with session to avoid SDK auth + host="my.databricks.com", + http_path="/sql/1.0/warehouses/abc", + schema="test_schema", + ) + # Switch to dbsql mode for validation test + creds.method = "dbsql" + with pytest.raises(DbtConfigError) as exc_info: + creds.validate_creds() + assert "oauth" in str(exc_info.value).lower() or "token" in str(exc_info.value).lower() diff --git a/tests/unit/test_session_python.py b/tests/unit/test_session_python.py new file mode 100644 index 000000000..0544ba449 --- /dev/null +++ b/tests/unit/test_session_python.py @@ -0,0 +1,189 @@ +"""Unit tests for session mode Python model submission.""" + +from unittest.mock import MagicMock, patch + +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.databricks.python_models.python_submissions import ( + SessionPythonJobHelper, + SessionPythonSubmitter, + SessionStateManager, +) + + +class TestSessionStateManager: + """Tests for SessionStateManager.""" + + def test_get_clean_exec_globals_includes_spark(self): + """Test that get_clean_exec_globals includes spark.""" + mock_spark = MagicMock() + + globals_dict = SessionStateManager.get_clean_exec_globals(mock_spark) + + assert globals_dict["spark"] is mock_spark + + def test_get_clean_exec_globals_includes_dbt(self): + """Test that get_clean_exec_globals includes dbt module.""" + mock_spark = MagicMock() + + globals_dict = SessionStateManager.get_clean_exec_globals(mock_spark) + + assert "dbt" in globals_dict + + def test_cleanup_temp_views_drops_temp_views(self): + """Test that cleanup_temp_views drops temporary views.""" + mock_spark = MagicMock() + mock_row = MagicMock() + mock_row.viewName = "temp_view_1" + mock_row.isTemporary = True + mock_spark.sql.return_value.collect.return_value = [mock_row] + + SessionStateManager.cleanup_temp_views(mock_spark) + + mock_spark.catalog.dropTempView.assert_called_once_with("temp_view_1") + + def test_cleanup_temp_views_ignores_non_temp_views(self): + """Test that cleanup_temp_views ignores non-temporary views.""" + mock_spark = MagicMock() + mock_row = MagicMock() + mock_row.viewName = "permanent_view" + mock_row.isTemporary = False + 
mock_spark.sql.return_value.collect.return_value = [mock_row] + + SessionStateManager.cleanup_temp_views(mock_spark) + + mock_spark.catalog.dropTempView.assert_not_called() + + +class TestSessionPythonSubmitter: + """Tests for SessionPythonSubmitter.""" + + @pytest.fixture + def mock_spark(self): + """Create a mock SparkSession.""" + return MagicMock() + + @pytest.fixture + def submitter(self, mock_spark): + """Create a SessionPythonSubmitter with mock SparkSession.""" + return SessionPythonSubmitter(mock_spark) + + def test_submit_executes_code(self, submitter, mock_spark): + """Test that submit executes the compiled code.""" + # Simple code that just sets a variable + compiled_code = "result = 1 + 1" + + submitter.submit(compiled_code) + + # The code should execute without error + + def test_submit_provides_spark_in_globals(self, submitter, mock_spark): + """Test that submit provides spark in the execution globals.""" + # Code that uses spark + compiled_code = "spark_app_name = spark.sparkContext.appName" + + mock_spark.sparkContext.appName = "test-app" + + submitter.submit(compiled_code) + + # The code should execute without error + + def test_submit_raises_on_execution_error(self, submitter, mock_spark): + """Test that submit raises DbtRuntimeError on execution error.""" + compiled_code = "raise ValueError('test error')" + + with pytest.raises(DbtRuntimeError) as exc_info: + submitter.submit(compiled_code) + + assert "Python model execution failed" in str(exc_info.value) + assert "test error" in str(exc_info.value) + + def test_submit_cleans_up_temp_views(self, submitter, mock_spark): + """Test that submit cleans up temp views after execution.""" + compiled_code = "result = 1" + mock_spark.sql.return_value.collect.return_value = [] + + submitter.submit(compiled_code) + + # Cleanup should be called (SHOW VIEWS) + mock_spark.sql.assert_called() + + +class TestSessionPythonJobHelper: + """Tests for SessionPythonJobHelper.""" + + @pytest.fixture + def mock_credentials(self): + """Create mock credentials.""" + creds = MagicMock() + creds.is_session_mode = True + return creds + + @pytest.fixture + def parsed_model_dict(self): + """Create a parsed model dictionary.""" + return { + "catalog": "main", + "schema": "test_schema", + "identifier": "test_model", + "config": { + "timeout": 3600, + "packages": [], + "index_url": None, + "additional_libs": [], + "python_job_config": {}, # Empty dict instead of None + "cluster_id": None, + "http_path": None, + "create_notebook": False, + "job_cluster_config": {}, # Empty dict instead of None + "access_control_list": [], + "notebook_access_control_list": [], + "user_folder_for_python": True, + "environment_key": None, + "environment_dependencies": [], + }, + } + + def test_init_gets_spark_session(self, mock_credentials, parsed_model_dict): + """Test that __init__ gets the SparkSession.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-123" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + helper = SessionPythonJobHelper(parsed_model_dict, mock_credentials) + + mock_builder.getOrCreate.assert_called_once() + assert helper._spark is mock_spark + + def test_submit_delegates_to_submitter(self, mock_credentials, parsed_model_dict): + """Test that submit delegates to the submitter.""" + mock_spark = MagicMock() + 
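+        # Stub the SparkSession the helper obtains via getOrCreate(), including
+        # the application id it reports as its session identifier.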
mock_spark.sparkContext.applicationId = "app-123" + mock_spark.sql.return_value.collect.return_value = [] + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + helper = SessionPythonJobHelper(parsed_model_dict, mock_credentials) + + compiled_code = "result = 1" + helper.submit(compiled_code) + + # The code should execute without error
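+
+
+# For reference, the submit() flow these tests assume is roughly the following
+# (a sketch under the plan's naming, not necessarily the shipped code):
+#
+#     def submit(self, compiled_code: str) -> None:
+#         exec_globals = SessionStateManager.get_clean_exec_globals(self._spark)
+#         try:
+#             exec(compiled_code, exec_globals)  # run the model in-session
+#         except Exception as e:
+#             raise DbtRuntimeError(f"Python model execution failed: {e}") from e
+#         finally:
+#             SessionStateManager.cleanup_temp_views(self._spark)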