diff --git a/PLAN_v2_session_mode.md b/PLAN_v2_session_mode.md
new file mode 100644
index 000000000..8894b0c29
--- /dev/null
+++ b/PLAN_v2_session_mode.md
@@ -0,0 +1,556 @@
+# Implementation Plan v2: Full Job Cluster Support for SQL + Python Models
+
+## Summary
+
+Enable complete dbt pipelines (SQL + Python models) to execute entirely on Databricks job clusters by:
+1. Adding `method: session` for SQL models using the active SparkSession
+2. Adding `session` submission method for Python models executing directly in the same session
+
+**Key Benefits:**
+- 🎯 **Primary Goal Achieved**: Run full dbt pipeline (SQL + Python) on a single job cluster
+- 💰 **70%+ Cost Savings**: Job clusters cost ~7¢ less than SQL Warehouses on Azure
+- 📊 **Better Observability**: Entire dbt run tracked in Databricks Jobs UI (vs. no tracking for SQL models before)
+- ⚡ **Faster Execution**: No job submission overhead for Python models
+- ✅ **Correct Behavior**: Sequential execution respecting DAG dependencies (as intended)
+- 🔧 **Simpler Setup**: No API authentication or permissions management needed
+
+## Background
+
+**Current State:**
+- SQL models require DBSQL connector (SQL Warehouses or All-Purpose clusters)
+- Python models can use job clusters via Jobs API (but spawn separate jobs)
+- No way to run a mixed pipeline entirely within a single job cluster session
+
+**Goal:**
+- Run complete dbt pipelines (SQL + Python) on a single job cluster
+- All models execute in the same SparkSession
+- Preserve all dbt-databricks features (Unity Catalog, Streaming Tables, etc.)
+
+## Architecture Overview
+
+```
+┌────────────────────────────────────────────────────────────┐
+│                   Databricks Job Cluster                    │
+│                                                             │
+│  ┌────────────────────────────────────────────────────┐    │
+│  │                    SparkSession                     │    │
+│  │                                                     │    │
+│  │  SQL Models    ──────►  spark.sql()                 │    │
+│  │                                                     │    │
+│  │  Python Models ──────►  exec() with spark context   │    │
+│  │                                                     │    │
+│  └────────────────────────────────────────────────────┘    │
+│                                                             │
+│  profiles.yml:  method: session                             │
+│  model config:  submission_method: session                  │
+└────────────────────────────────────────────────────────────┘
+```
+
+## Implementation Approach
+
+### Phase 1: Add Session Connection Method (SQL Models)
+
+**File: [credentials.py](dbt/adapters/databricks/credentials.py)**
+
+1. Add `method` field to `DatabricksCredentials`:
+```python
+@dataclass
+class DatabricksCredentials(Credentials):
+    # ... existing fields ...
+    method: Optional[str] = None  # "session" or "dbsql" (default)
+```
+
+2. 
Update validation in `__post_init__`: +```python +# Auto-detect session mode +if self.method is None: + if os.getenv("DBT_DATABRICKS_SESSION_MODE", "").lower() == "true": + self.method = "session" + elif os.getenv("DATABRICKS_RUNTIME_VERSION") and not self.host: + self.method = "session" + else: + self.method = "dbsql" + +# Validate based on method +if self.method == "session": + self._validate_session_mode() +else: + self.validate_creds() +``` + +3. Add session validation: +```python +def _validate_session_mode(self) -> None: + try: + from pyspark.sql import SparkSession # noqa + except ImportError: + raise DbtRuntimeError("Session mode requires pyspark") + if self.schema is None: + raise DbtValidationError("Schema is required for session mode") +``` + +### Phase 2: Create Session Handle Module (SQL Execution) + +**New File: [session.py](dbt/adapters/databricks/session.py)** + +1. `SessionCursorWrapper` - Adapts DataFrame to cursor interface: +```python +class SessionCursorWrapper: + def __init__(self, spark: SparkSession): + self._spark = spark + self._df: Optional[DataFrame] = None + self._rows: Optional[list[Row]] = None + + def execute(self, sql: str, bindings=None) -> "SessionCursorWrapper": + cleaned_sql = sql.strip().rstrip(";") + if bindings: + cleaned_sql = cleaned_sql % tuple(bindings) + self._df = self._spark.sql(cleaned_sql) + return self + + def fetchall(self) -> list[tuple]: + if self._rows is None and self._df: + self._rows = self._df.collect() + return [tuple(row) for row in (self._rows or [])] + + @property + def description(self) -> list[tuple]: + if self._df is None: + return [] + return [(f.name, f.dataType.simpleString(), None, None, None, None, f.nullable) + for f in self._df.schema.fields] +``` + +2. `DatabricksSessionHandle` - Wraps SparkSession: +```python +class DatabricksSessionHandle: + def __init__(self, spark: SparkSession): + self._spark = spark + + @staticmethod + def create(catalog=None, schema=None, session_properties=None): + from pyspark.sql import SparkSession + spark = SparkSession.builder.getOrCreate() + if catalog: + spark.catalog.setCurrentCatalog(catalog) + if schema: + spark.catalog.setCurrentDatabase(schema) + if session_properties: + for k, v in session_properties.items(): + spark.conf.set(k, v) + return DatabricksSessionHandle(spark) + + @property + def dbr_version(self) -> tuple[int, int]: + version_str = self._spark.conf.get( + "spark.databricks.clusterUsageTags.sparkVersion" + ) + return SqlUtils.extract_dbr_version(version_str) + + def execute(self, sql, bindings=None) -> SessionCursorWrapper: + return SessionCursorWrapper(self._spark).execute(sql, bindings) +``` + +### Phase 3: Update Connection Manager + +**File: [connections.py](dbt/adapters/databricks/connections.py)** + +1. Modify `open()` to dispatch based on method: +```python +@classmethod +def open(cls, connection: Connection) -> Connection: + creds: DatabricksCredentials = connection.credentials + + if creds.method == "session": + return cls._open_session(connection, creds, databricks_connection) + + # Existing DBSQL logic... +``` + +2. Add `_open_session()` method for session-based connections. + +### Phase 4: Add Session Python Submission Helper (Python Models) + +**File: [python_submissions.py](dbt/adapters/databricks/python_models/python_submissions.py)** + +1. Add new `SessionPythonSubmitter`: +```python +class SessionPythonSubmitter(PythonSubmitter): + """Submitter for Python models using direct execution in current SparkSession. 
+ + NOTE: This does NOT collect data to the driver. The compiled code contains + df.write.saveAsTable() which writes directly to storage, just like API-based + submission methods. + """ + + def __init__(self, spark: SparkSession): + self._spark = spark + + @override + def submit(self, compiled_code: str) -> None: + logger.debug("Executing Python model directly in SparkSession.") + + # Create execution context with spark available + # The compiled code will: + # 1. Execute model() function to get a DataFrame + # 2. Call df.write.saveAsTable() to persist to Delta + # 3. No collect() - data stays distributed + exec_globals = { + "spark": self._spark, + "dbt": __import__("dbt"), + } + + # Execute the compiled Python model code + exec(compiled_code, exec_globals) +``` + +2. Add session state cleanup utilities: +```python +class SessionStateManager: + """Manages session state to prevent leakage between Python models.""" + + @staticmethod + def cleanup_temp_views(spark: SparkSession) -> None: + """Drop temporary views created during model execution.""" + # Get list of temp views + temp_views = [row.name for row in spark.sql("SHOW VIEWS").collect() + if row.isTemporary] + for view in temp_views: + spark.catalog.dropTempView(view) + + @staticmethod + def get_clean_exec_globals(spark: SparkSession) -> dict: + """Return a clean execution context with minimal state.""" + return { + "spark": spark, + "dbt": __import__("dbt"), + # Add other safe imports as needed + } +``` + +3. Update `SessionPythonSubmitter` with error handling: +```python +class SessionPythonSubmitter(PythonSubmitter): + def __init__(self, spark: SparkSession): + self._spark = spark + self._state_manager = SessionStateManager() + + @override + def submit(self, compiled_code: str) -> None: + logger.debug("Executing Python model directly in SparkSession.") + + try: + # Get clean execution context + exec_globals = self._state_manager.get_clean_exec_globals(self._spark) + + # Execute the compiled Python model code + exec(compiled_code, exec_globals) + + except Exception as e: + logger.error(f"Python model execution failed: {e}") + raise DbtRuntimeError(f"Python model execution failed: {e}") from e + + finally: + # Clean up temp views to prevent state leakage + try: + self._state_manager.cleanup_temp_views(self._spark) + except Exception as cleanup_error: + logger.warning(f"Failed to cleanup temp views: {cleanup_error}") +``` + +4. Add new `SessionPythonJobHelper`: +```python +class SessionPythonJobHelper(PythonJobHelper): + """Helper for Python models executing directly in session mode.""" + + tracker = PythonRunTracker() + + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: + self.credentials = credentials + self.parsed_model = ParsedPythonModel(**parsed_model) + + # Get SparkSession directly + from pyspark.sql import SparkSession + self._spark = SparkSession.builder.getOrCreate() + + def submit(self, compiled_code: str) -> None: + submitter = SessionPythonSubmitter(self._spark) + submitter.submit(compiled_code) +``` + +**File: [impl.py](dbt/adapters/databricks/impl.py)** + +3. Register the new submission helper: +```python +@property +def python_submission_helpers(self) -> dict[str, type[PythonJobHelper]]: + return { + "job_cluster": JobClusterPythonJobHelper, + "all_purpose_cluster": AllPurposeClusterPythonJobHelper, + "serverless_cluster": ServerlessClusterPythonJobHelper, + "workflow_job": WorkflowPythonJobHelper, + "session": SessionPythonJobHelper, # NEW + } +``` + +4. 
Update `submit_python_job()` to auto-select session mode: +```python +def submit_python_job(self, parsed_model: dict, compiled_code: str) -> AdapterResponse: + # Auto-select session submission when in session mode + creds = self.config.credentials + if creds.method == "session": + if parsed_model["config"].get("submission_method") is None: + parsed_model["config"]["submission_method"] = "session" + + # ... existing code ... + return super().submit_python_job(parsed_model, compiled_code) +``` + +### Phase 5: Testing + +**New Files:** +- `tests/unit/test_session.py` - Unit tests for session handle and cursor +- `tests/unit/test_session_python.py` - Unit tests for session Python submission +- `tests/functional/adapter/session/test_session_mixed_pipeline.py` - Integration tests + +## Files to Modify + +| File | Changes | +|------|---------| +| [credentials.py](dbt/adapters/databricks/credentials.py) | Add `method` field, auto-detection, validation | +| [connections.py](dbt/adapters/databricks/connections.py) | Add session dispatch in `open()`, capabilities caching | +| [impl.py](dbt/adapters/databricks/impl.py) | Register `SessionPythonJobHelper`, auto-select session mode | +| [python_submissions.py](dbt/adapters/databricks/python_models/python_submissions.py) | Add `SessionPythonSubmitter`, `SessionPythonJobHelper` | +| **NEW** [session.py](dbt/adapters/databricks/session.py) | `SessionCursorWrapper`, `DatabricksSessionHandle` | +| **NEW** [tests/unit/test_session.py](tests/unit/test_session.py) | Unit tests | + +## Configuration + +**profiles.yml for full session mode:** +```yaml +my_project: + target: job_cluster + outputs: + job_cluster: + type: databricks + method: session + catalog: main + schema: my_schema + # host/http_path/token NOT required +``` + +**Model-level override (optional):** +```sql +-- For Python models that need specific submission method +{{ config(submission_method='session') }} +``` + +**Environment variable:** +```bash +export DBT_DATABRICKS_SESSION_MODE=true +``` + +## Usage Scenarios + +### Scenario 1: Run dbt from Databricks Notebook on Job Cluster +```python +# In a Databricks notebook task within a job +from dbt.cli.main import dbtRunner + +dbt = dbtRunner() +result = dbt.invoke(["run"]) # Uses session mode automatically +``` + +### Scenario 2: Run dbt from Python Script Task +```python +# python_task.py - executed as a Python script task in a Databricks job +import subprocess +subprocess.run(["dbt", "run", "--profiles-dir", "/dbfs/path/to/profiles"]) +``` + +### Scenario 3: Databricks Workflow with dbt Task +```yaml +# Databricks job definition +tasks: + - task_key: run_dbt + job_cluster_key: my_cluster + python_wheel_task: + package_name: my_dbt_project + entry_point: run_dbt + libraries: + - pypi: {package: dbt-databricks} +``` + +## Verification Plan + +1. **Unit tests**: `pytest tests/unit/test_session*.py` +2. **Local validation**: Test credentials validation for both modes +3. 
**Integration test**:
+   ```python
+   # Run in Databricks notebook on job cluster
+   from dbt.cli.main import dbtRunner
+
+   dbt = dbtRunner()
+
+   # Test SQL model
+   result = dbt.invoke(["run", "--select", "my_sql_model"])
+   assert result.success
+
+   # Test Python model
+   result = dbt.invoke(["run", "--select", "my_python_model"])
+   assert result.success
+
+   # Test full pipeline
+   result = dbt.invoke(["run"])
+   assert result.success
+   ```
+
+## Session vs API Client Comparison
+
+### Data Collection (Python Models)
+
+**Both approaches write directly to tables without collecting to driver:**
+
+1. **API Client (Jobs/Command API)**:
+   - Executes: `df.write.saveAsTable("table_name")`
+   - Runs in separate context/notebook
+   - Data written directly to storage
+
+2. **Session Mode (exec())**:
+   - Executes: `df.write.saveAsTable("table_name")` (same code!)
+   - Runs in same SparkSession
+   - Data written directly to storage
+
+**Result**: No difference in data collection behavior - both are distributed writes. (A minimal sketch of this pattern appears at the end of this comparison.)
+
+### Functional Differences
+
+| Feature | API Client | Session Mode |
+|---------|------------|--------------|
+| **Execution Location** | Separate job/context | Same SparkSession |
+| **Isolation** | Isolated execution | Shared session state |
+| **Retry Logic** | Built-in via Jobs API | Manual if needed |
+| **Monitoring** | Databricks job UI | Spark UI only |
+| **Permissions** | Requires cluster access permissions | Uses current session |
+| **Async Execution** | Submits and polls | Blocks until complete |
+| **Library Installation** | Via cluster config or job spec | Must be pre-installed |
+| **Resource Limits** | Can specify per-job | Uses session limits |
+
+### Session Mode vs API Client Trade-offs
+
+**IMPORTANT CONTEXT**: These differences only apply to **Python model execution**. SQL models have never had these API client features. Session mode actually **improves** observability by allowing the entire dbt pipeline to run within a Databricks job.
+
+| Aspect | SQL Models (Before) | Session Mode (SQL + Python) | API Client (Python only) |
+|--------|---------------------|-----------------------------|--------------------------|
+| **Pipeline Observability** | ❌ No Databricks job tracking | ✅ **Full dbt run tracked in Databricks job** | ✅ Individual Python models tracked |
+| **Cost** | Requires SQL Warehouse/All-Purpose | ✅ **Single job cluster** | Multiple job clusters |
+| **Sequential Execution** | ✅ Follows DAG | ✅ **Follows DAG (intended)** | ✅ Follows DAG |
+| **Per-Model Monitoring** | ❌ Not in Databricks UI | ❌ Not in Databricks UI | ✅ Each Python model visible |
+| **Library Management** | Pre-installed | Pre-installed | ✅ Per-job dynamic install |
+| **Failure Behavior** | Fails dbt run | Fails dbt run | ✅ Per-model retry |
+
+**Key Insight**: The "drawbacks" listed are actually **not regressions** - SQL models never had per-model Databricks job monitoring, dynamic libraries, or isolated retry logic. Session mode brings:
+
+✅ **New Capability**: Run entire dbt pipeline on job cluster (previously impossible)
+✅ **Better Observability**: Whole dbt run visible in Databricks job (vs. no tracking before)
+✅ **Cost Savings**: Single cluster for SQL + Python (vs. SQL Warehouse + separate job clusters)
+✅ **Correct Behavior**: Sequential execution following data lineage (as intended)
+
+**The Only Real Trade-off**: Per-Python-model granularity in Databricks UI vs. whole-pipeline execution efficiency.
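+
+To make the "same compiled code" comparison concrete, here is a minimal sketch of a dbt Python model and the write pattern applied around it. The model, column, and table names are illustrative only, and the wrapper shown is a simplification of what dbt generates, not its exact form.
+
+```python
+# models/daily_orders.py -- illustrative dbt Python model (names are examples)
+def model(dbt, session):
+    orders = dbt.ref("stg_orders")  # Spark DataFrame; stays distributed
+    return orders.groupBy("order_date").count()
+
+
+# Simplified shape of the compiled wrapper that both submission paths execute:
+#
+#   df = model(dbt, spark)
+#   df.write.saveAsTable("catalog.schema.daily_orders")
+#
+# The write is distributed; no collect() is issued, so rows never move to the driver.
+```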
+
+### When to Use Each Approach
+
+**Use Session Mode (Recommended for most use cases):**
+- ✅ Running dbt from Databricks job (notebooks, tasks, workflows)
+- ✅ Cost optimization is priority (single cluster for entire pipeline)
+- ✅ Want Databricks job-level tracking of dbt runs
+- ✅ Sequential execution following DAG is acceptable (it should be!)
+- ✅ Libraries can be installed on cluster beforehand
+
+**Use API Client (for specific advanced scenarios):**
+- Need per-Python-model granular monitoring in Databricks UI
+- Require per-model dynamic library installation
+- Want independent retry logic for each Python model
+- Need to run Python models on different cluster configurations
+- Async submission of Python models (non-DAG execution)
+
+## Limitations
+
+**Actual Limitations:**
+- **No API-based features**: Session mode doesn't have host/token, so API-dependent features (like standalone workflow job creation) won't work
+- **No DBSQL query history**: Queries don't appear in Databricks SQL query history (but entire run appears in job logs)
+- **Single cluster compute**: All models use same cluster resources (this is the goal for cost optimization)
+
+**Not Limitations (these are correct behaviors):**
+- ✅ Python models don't appear as separate jobs in Databricks UI → **Correct**: They're part of the main dbt job
+- ✅ Models execute sequentially following DAG → **Correct**: This respects data dependencies
+- ✅ Libraries must be pre-installed → **Normal**: Same as SQL models; install on cluster startup
+
+## Design Decisions
+
+1. **Configuration**: Use `method: session` in profiles.yml (aligns with dbt-spark)
+2. **Python execution**: Direct `exec()` in current SparkSession (simplest, most compatible)
+3. **Auto-detection**: When `DATABRICKS_RUNTIME_VERSION` is set and no host configured, default to session mode
+4. **Backward compatibility**: Default behavior unchanged when `method` not specified with host/token present
+
+## Migration Path
+
+For existing users:
+1. No changes required if using DBSQL mode (default)
+2. To use session mode, add `method: session` to profile
+3. Python models automatically use session submission when profile is in session mode
+
+## FAQ
+
+### Q: When will Python model results be collected to the driver?
+
+**A: Never.** Both session mode and API client mode execute the same compiled code, which ends with:
+```python
+df.write.saveAsTable("table_name")
+```
+
+This writes the DataFrame directly to Delta/storage in a distributed manner. No `collect()` is called, so data stays distributed across the cluster. The only data transferred is metadata (schema, row counts, etc.).
+
+### Q: What are the drawbacks of not using the API client for Python models?
+
+**A: The "drawbacks" are actually not regressions - they're trade-offs that already existed for SQL models:**
+
+**What you "lose" (but SQL models never had):**
+1. Per-Python-model visibility in Databricks Jobs UI (vs. per-notebook jobs)
+2. Per-Python-model retry logic (vs. whole dbt run retry)
+3. Dynamic library installation per model
+
+**What you GAIN with session mode:**
+1. ✅ **Run entire pipeline on job cluster** (previously impossible)
+2. ✅ **Databricks job tracking of full dbt run** (better than no tracking)
+3. ✅ **70%+ cost savings** - Single cluster vs. SQL Warehouse + job clusters
+4. ✅ **Faster execution** - No job submission overhead
+5. ✅ **Correct DAG execution** - Sequential following lineage (as intended)
+6. ✅ **Simpler setup** - No API authentication needed
+
+**Context**: SQL models have always executed sequentially, blocked downstream models on failure, and lacked per-model Databricks UI tracking. Session mode brings Python models to the same (correct) behavior while enabling the entire pipeline to run on cost-effective job clusters.
+
+**Recommendation**: Use session mode for production pipelines running from Databricks jobs. It's more cost-effective, properly respects the DAG, and provides job-level observability.
+
+### Q: How do I handle failures in session mode?
+
+**A: Options:**
+
+1. **Databricks job-level retry**: Configure the Databricks job that runs dbt to retry on failure
+2. **dbt retry logic**: Use the `dbt retry` command after failures
+3. **Model-level error handling**: Add try/except in Python model code if needed
+4. **Checkpointing**: Use incremental models to avoid re-running successful work
+
+### Q: Can I mix session mode and API client mode?
+
+**A: Not recommended.** Stick to one approach per dbt run for consistency. However, you could:
+- Use session mode for dev/testing
+- Use API client mode for production
+- Configure via environment-specific profiles
+
+## Future Considerations
+
+- Consider submitting as PR to dbt-databricks upstream
+- May need periodic sync with dbt-databricks updates if maintained as fork
+- Could add support for hybrid mode (session SQL + Jobs API Python) if needed
+- Consider adding session-level Python model parallelization using ThreadPoolExecutor
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 291ac7677..cbbd171d3 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -49,6 +49,7 @@
 from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils
 from dbt.adapters.databricks.logging import logger
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
+from dbt.adapters.databricks.session import DatabricksSessionHandle, SessionCursorWrapper
 from dbt.adapters.databricks.utils import QueryTagsUtils, is_cluster_http_path, redact_credentials
 
 if TYPE_CHECKING:
@@ -150,6 +151,8 @@ class DatabricksConnectionManager(SparkConnectionManager):
     TYPE: str = "databricks"
     credentials_manager: Optional[DatabricksCredentialManager] = None
     _dbr_capabilities_cache: dict[str, DBRCapabilities] = {}
+    # Cache for session mode capabilities (keyed by "session")
+    _session_capabilities: Optional[DBRCapabilities] = None
 
     def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
         super().__init__(profile, mp_context)
@@ -159,9 +162,19 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
     def api_client(self) -> DatabricksApiClient:
         if self._api_client is None:
             credentials = cast(DatabricksCredentials, self.profile.credentials)
+            if credentials.is_session_mode:
+                raise DbtRuntimeError(
+                    "API client is not available in session mode. "
+                    "Session mode does not support API-based operations."
+ ) self._api_client = DatabricksApiClient(credentials, 15 * 60) return self._api_client + def is_session_mode(self) -> bool: + """Check if the connection is using session mode.""" + credentials = cast(DatabricksCredentials, self.profile.credentials) + return credentials.is_session_mode + def is_cluster(self) -> bool: conn = self.get_thread_connection() databricks_conn = cast(DatabricksDBTConnection, conn) @@ -210,14 +223,16 @@ def _cache_dbr_capabilities(cls, creds: DatabricksCredentials, http_path: str) - def cancel_open(self) -> list[str]: cancelled = super().cancel_open() - logger.info("Cancelling open python jobs") - PythonRunTracker.cancel_runs(self.api_client) + # Only cancel Python jobs via API if not in session mode + if not self.is_session_mode(): + logger.info("Cancelling open python jobs") + PythonRunTracker.cancel_runs(self.api_client) return cancelled def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) - handle: DatabricksHandle = self.get_thread_connection().handle + handle: DatabricksHandle | DatabricksSessionHandle = self.get_thread_connection().handle dbr_version = handle.dbr_version return (dbr_version > version) - (dbr_version < version) @@ -321,7 +336,7 @@ def add_query( fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) with self.exception_handler(sql): - cursor: Optional[CursorWrapper] = None + cursor: Optional[CursorWrapper | SessionCursorWrapper] = None try: log_sql = redact_credentials(sql) if abridge_sql_log: @@ -337,7 +352,7 @@ def add_query( pre = time.time() - handle: DatabricksHandle = connection.handle + handle: DatabricksHandle | DatabricksSessionHandle = connection.handle cursor = handle.execute(sql, bindings) response = self.get_response(cursor) fire_event( @@ -380,14 +395,18 @@ def execute( cursor.close() def _execute_with_cursor( - self, log_sql: str, f: Callable[[DatabricksHandle], CursorWrapper] + self, + log_sql: str, + f: Callable[ + [DatabricksHandle | DatabricksSessionHandle], CursorWrapper | SessionCursorWrapper + ], ) -> "Table": connection = self.get_thread_connection() fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name))) with self.exception_handler(log_sql): - cursor: Optional[CursorWrapper] = None + cursor: Optional[CursorWrapper | SessionCursorWrapper] = None try: fire_event( SQLQuery( @@ -399,7 +418,7 @@ def _execute_with_cursor( pre = time.time() - handle: DatabricksHandle = connection.handle + handle: DatabricksHandle | DatabricksSessionHandle = connection.handle cursor = f(handle) response = self.get_response(cursor) @@ -464,9 +483,66 @@ def open(cls, connection: Connection) -> Connection: return connection creds: DatabricksCredentials = connection.credentials + + # Dispatch based on connection method + if creds.is_session_mode: + return cls._open_session(databricks_connection, creds) + else: + return cls._open_dbsql(databricks_connection, creds) + + @classmethod + def _open_session( + cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials + ) -> Connection: + """Open a connection using SparkSession mode.""" + logger.debug("Opening connection in session mode") + + def connect() -> DatabricksSessionHandle: + try: + handle = DatabricksSessionHandle.create( + catalog=creds.database, + schema=creds.schema, + session_properties=creds.session_properties, + ) + databricks_connection.session_id = handle.session_id + + # Cache capabilities for session mode + cls._cache_session_capabilities(handle) + 
databricks_connection.capabilities = cls._session_capabilities or DBRCapabilities() + + logger.debug(f"Session mode connection opened: {handle}") + return handle + except Exception as exc: + logger.error(ConnectionCreateError(exc)) + raise DbtDatabaseError(f"Failed to create session connection: {exc}") from exc + + # Session mode doesn't need retry logic as SparkSession is already available + databricks_connection.handle = connect() + databricks_connection.state = ConnectionState.OPEN + return databricks_connection + + @classmethod + def _cache_session_capabilities(cls, handle: DatabricksSessionHandle) -> None: + """Cache DBR capabilities for session mode.""" + if cls._session_capabilities is None: + dbr_version = handle.dbr_version + cls._session_capabilities = DBRCapabilities( + dbr_version=dbr_version, + is_sql_warehouse=False, # Session mode is always on a cluster + ) + logger.debug(f"Cached session capabilities: DBR version {dbr_version}") + + @classmethod + def _open_dbsql( + cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials + ) -> Connection: + """Open a connection using DBSQL connector.""" timeout = creds.connect_timeout - cls.credentials_manager = creds.authenticate() + credentials_manager = creds.authenticate() + # In DBSQL mode, authenticate() always returns a credentials manager + assert credentials_manager is not None, "Credentials manager is required for DBSQL mode" + cls.credentials_manager = credentials_manager # Get merged query tags if we have query header context query_header_context = getattr(databricks_connection, "_query_header_context", None) @@ -475,7 +551,7 @@ def open(cls, connection: Connection) -> Connection: merged_query_tags = QueryConfigUtils.get_merged_query_tags(query_header_context, creds) conn_args = SqlUtils.prepare_connection_arguments( - creds, cls.credentials_manager, databricks_connection.http_path, merged_query_tags + creds, credentials_manager, databricks_connection.http_path, merged_query_tags ) def connect() -> DatabricksHandle: @@ -507,7 +583,7 @@ def exponential_backoff(attempt: int) -> int: retryable_exceptions = [Error] return cls.retry_connection( - connection, + databricks_connection, connect=connect, logger=logger, retryable_exceptions=retryable_exceptions, @@ -527,7 +603,7 @@ def close(cls, connection: Connection) -> Connection: @classmethod def get_response(cls, cursor: Any) -> AdapterResponse: - if isinstance(cursor, CursorWrapper): + if isinstance(cursor, (CursorWrapper, SessionCursorWrapper)): return cursor.get_response() else: return AdapterResponse("OK") diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 7e8af786b..d0a9745d1 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -1,12 +1,13 @@ import itertools import json +import os import re from collections.abc import Iterable from dataclasses import dataclass, field from typing import Any, Callable, Optional, cast from dbt.adapters.contracts.connection import Credentials -from dbt_common.exceptions import DbtConfigError, DbtValidationError +from dbt_common.exceptions import DbtConfigError, DbtRuntimeError, DbtValidationError from mashumaro import DataClassDictMixin from requests import PreparedRequest from requests.auth import AuthBase @@ -16,6 +17,14 @@ from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger +# Connection method constants +CONNECTION_METHOD_SESSION = "session" +CONNECTION_METHOD_DBSQL 
= "dbsql" + +# Environment variable for session mode +DBT_DATABRICKS_SESSION_MODE_ENV = "DBT_DATABRICKS_SESSION_MODE" +DATABRICKS_RUNTIME_VERSION_ENV = "DATABRICKS_RUNTIME_VERSION" + CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$") EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)") @@ -48,6 +57,9 @@ class DatabricksCredentials(Credentials): connection_parameters: Optional[dict[str, Any]] = None auth_type: Optional[str] = None + # Connection method: "session" for SparkSession mode, "dbsql" for DBSQL connector (default) + method: Optional[str] = None + # Named compute resources specified in the profile. Used for # creating a connection when a model specifies a compute resource. compute: Optional[dict[str, Any]] = None @@ -102,6 +114,9 @@ def __post_init__(self) -> None: else: self.database = "hive_metastore" + # Auto-detect and validate connection method + self._init_connection_method() + connection_parameters = self.connection_parameters or {} for key in ( "server_hostname", @@ -130,9 +145,57 @@ def __post_init__(self) -> None: if "_socket_timeout" not in connection_parameters: connection_parameters["_socket_timeout"] = 600 self.connection_parameters = connection_parameters - self._credentials_manager = DatabricksCredentialManager.create_from(self) + + # Only create credentials manager for non-session mode + if not self.is_session_mode: + self._credentials_manager = DatabricksCredentialManager.create_from(self) + + def _init_connection_method(self) -> None: + """Initialize and validate the connection method.""" + if self.method is None: + # Auto-detect session mode + if os.getenv(DBT_DATABRICKS_SESSION_MODE_ENV, "").lower() == "true": + self.method = CONNECTION_METHOD_SESSION + elif os.getenv(DATABRICKS_RUNTIME_VERSION_ENV) and not self.host: + # Running on Databricks cluster without host configured + self.method = CONNECTION_METHOD_SESSION + else: + self.method = CONNECTION_METHOD_DBSQL + + # Validate method value + if self.method not in (CONNECTION_METHOD_SESSION, CONNECTION_METHOD_DBSQL): + raise DbtValidationError( + f"Invalid connection method: '{self.method}'. " + f"Must be '{CONNECTION_METHOD_SESSION}' or '{CONNECTION_METHOD_DBSQL}'." + ) + + @property + def is_session_mode(self) -> bool: + """Check if using session mode (SparkSession) for connections.""" + return self.method == CONNECTION_METHOD_SESSION + + def _validate_session_mode(self) -> None: + """Validate configuration for session mode.""" + try: + from pyspark.sql import SparkSession # noqa: F401 + except ImportError: + raise DbtRuntimeError( + "Session mode requires PySpark. " + "Please ensure you are running on a Databricks cluster with PySpark available." 
+ ) + + if self.schema is None: + raise DbtValidationError("Schema is required for session mode.") def validate_creds(self) -> None: + """Validate credentials based on connection method.""" + if self.is_session_mode: + self._validate_session_mode() + else: + self._validate_dbsql_creds() + + def _validate_dbsql_creds(self) -> None: + """Validate credentials for DBSQL connector mode.""" for key in ["host", "http_path"]: if not getattr(self, key): raise DbtConfigError(f"The config '{key}' is required to connect to Databricks") @@ -197,6 +260,9 @@ def type(self) -> str: @property def unique_field(self) -> str: + if self.is_session_mode: + # For session mode, use a unique identifier based on catalog and schema + return f"session://{self.database}/{self.schema}" return cast(str, self.host) def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str, Any]]: @@ -209,6 +275,15 @@ def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str, if key in as_dict: yield key, as_dict[key] + def _connection_keys_session(self) -> tuple[str, ...]: + """Connection keys for session mode.""" + connection_keys = ["method", "schema"] + if self.database: + connection_keys.insert(1, "catalog") + if self.session_properties: + connection_keys.append("session_properties") + return tuple(connection_keys) + def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]: # Assuming `DatabricksCredentials.connection_info(self, *, with_aliases: bool = False)` # is called from only: @@ -218,6 +293,11 @@ def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]: # # Thus, if `with_aliases` is `True`, `DatabricksCredentials._connection_keys` should return # the internal key names; otherwise it can use aliases to show in `dbt debug`. + + # Session mode has different connection keys + if self.is_session_mode: + return self._connection_keys_session() + connection_keys = ["host", "http_path", "schema"] if with_aliases: connection_keys.insert(2, "database") @@ -239,8 +319,15 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]: def cluster_id(self) -> Optional[str]: return self.extract_cluster_id(self.http_path) # type: ignore[arg-type] - def authenticate(self) -> "DatabricksCredentialManager": + def authenticate(self) -> Optional["DatabricksCredentialManager"]: + """Authenticate and return credentials manager. + + For session mode, returns None as no external authentication is needed. + For DBSQL mode, validates credentials and returns the credentials manager. + """ self.validate_creds() + if self.is_session_mode: + return None assert self._credentials_manager is not None, "Credentials manager is not set." 
return self._credentials_manager diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 2190e1f9c..efbe51f55 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -50,13 +50,16 @@ DatabricksConnectionManager, DatabricksDBTConnection, ) +from dbt.adapters.databricks.credentials import DatabricksCredentials from dbt.adapters.databricks.dbr_capabilities import DBRCapability from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.handle import SqlUtils +from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.python_submissions import ( AllPurposeClusterPythonJobHelper, JobClusterPythonJobHelper, ServerlessClusterPythonJobHelper, + SessionPythonJobHelper, WorkflowPythonJobHelper, ) from dbt.adapters.databricks.relation import ( @@ -801,6 +804,7 @@ def python_submission_helpers(self) -> dict[str, type[PythonJobHelper]]: "all_purpose_cluster": AllPurposeClusterPythonJobHelper, "serverless_cluster": ServerlessClusterPythonJobHelper, "workflow_job": WorkflowPythonJobHelper, + "session": SessionPythonJobHelper, } @log_code_execution @@ -809,6 +813,14 @@ def submit_python_job(self, parsed_model: dict, compiled_code: str) -> AdapterRe "user_folder_for_python", self.behavior.use_user_folder_for_python.setting, # type: ignore[attr-defined] ) + + # Auto-select session submission when in session mode + creds = cast(DatabricksCredentials, self.config.credentials) + if creds.is_session_mode: + if parsed_model["config"].get("submission_method") is None: + parsed_model["config"]["submission_method"] = "session" + logger.debug("Auto-selected 'session' submission method for Python model") + return super().submit_python_job(parsed_model, compiled_code) @available diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py index 23c2a27c1..a04ef8db8 100644 --- a/dbt/adapters/databricks/python_models/python_submissions.py +++ b/dbt/adapters/databricks/python_models/python_submissions.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from dbt.adapters.base import PythonJobHelper from dbt_common.exceptions import DbtRuntimeError @@ -12,6 +12,9 @@ from dbt.adapters.databricks.python_models.python_config import ParsedPythonModel from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker +if TYPE_CHECKING: + from pyspark.sql import SparkSession + DEFAULT_TIMEOUT = 60 * 60 * 24 @@ -658,3 +661,111 @@ def build_submitter(self) -> PythonSubmitter: return PythonNotebookWorkflowSubmitter.create( self.api_client, self.tracker, self.parsed_model ) + + +class SessionStateManager: + """Manages session state to prevent leakage between Python models.""" + + @staticmethod + def cleanup_temp_views(spark: "SparkSession") -> None: + """Drop temporary views created during model execution.""" + try: + # Get list of temp views from the current database + temp_views = [ + row.viewName + for row in spark.sql("SHOW VIEWS").collect() + if hasattr(row, "isTemporary") and row.isTemporary + ] + for view in temp_views: + try: + spark.catalog.dropTempView(view) + logger.debug(f"Dropped temp view: {view}") + except Exception as e: + logger.warning(f"Failed to drop temp view {view}: {e}") + except Exception as e: + logger.debug(f"Could not list temp views for cleanup: {e}") + + @staticmethod + def get_clean_exec_globals(spark: 
"SparkSession") -> dict[str, Any]: + """Return a clean execution context with minimal state.""" + return { + "spark": spark, + "dbt": __import__("dbt"), + # Standard Python builtins are available by default + } + + +class SessionPythonSubmitter(PythonSubmitter): + """Submitter for Python models using direct execution in current SparkSession. + + NOTE: This does NOT collect data to the driver. The compiled code contains + df.write.saveAsTable() which writes directly to storage, just like API-based + submission methods. + """ + + def __init__(self, spark: "SparkSession"): + self._spark = spark + self._state_manager = SessionStateManager() + + @override + def submit(self, compiled_code: str) -> None: + logger.debug("Executing Python model directly in SparkSession.") + + try: + # Get clean execution context + exec_globals = self._state_manager.get_clean_exec_globals(self._spark) + + # Log a preview of the code being executed + preview_len = min(500, len(compiled_code)) + logger.debug( + f"[Session Python] Executing code preview: {compiled_code[:preview_len]}..." + ) + + # Execute the compiled Python model code + # The compiled code will: + # 1. Execute model() function to get a DataFrame + # 2. Call df.write.saveAsTable() to persist to Delta + # 3. No collect() - data stays distributed + exec(compiled_code, exec_globals) + + logger.debug("[Session Python] Model execution completed successfully") + + except Exception as e: + logger.error(f"Python model execution failed: {e}") + raise DbtRuntimeError(f"Python model execution failed: {e}") from e + + finally: + # Clean up temp views to prevent state leakage + try: + self._state_manager.cleanup_temp_views(self._spark) + except Exception as cleanup_error: + logger.warning(f"Failed to cleanup temp views: {cleanup_error}") + + +class SessionPythonJobHelper(PythonJobHelper): + """Helper for Python models executing directly in session mode. + + This helper executes Python models directly in the current SparkSession + without using the Databricks API. It's designed for running dbt on + job clusters where the SparkSession is already available. + """ + + tracker = PythonRunTracker() + + def __init__(self, parsed_model: dict, credentials: DatabricksCredentials) -> None: + self.credentials = credentials + self.parsed_model = ParsedPythonModel(**parsed_model) + + # Get SparkSession directly - no API client needed + from pyspark.sql import SparkSession + + self._spark = SparkSession.builder.getOrCreate() + logger.debug( + f"[Session Python] Using SparkSession: {self._spark.sparkContext.applicationId}" + ) + + self._submitter = SessionPythonSubmitter(self._spark) + + def submit(self, compiled_code: str) -> None: + """Submit the compiled Python model for execution.""" + self._submitter.submit(compiled_code) diff --git a/dbt/adapters/databricks/session.py b/dbt/adapters/databricks/session.py new file mode 100644 index 000000000..68df6b046 --- /dev/null +++ b/dbt/adapters/databricks/session.py @@ -0,0 +1,282 @@ +""" +Session mode support for dbt-databricks. + +This module provides SparkSession-based execution for running dbt on Databricks job clusters +without requiring the DBSQL connector. It enables complete dbt pipelines (SQL + Python models) +to execute entirely within a single SparkSession. 
+ +Key components: +- SessionCursorWrapper: Adapts DataFrame results to the cursor interface expected by dbt +- DatabricksSessionHandle: Wraps SparkSession to provide the handle interface +""" + +import sys +from collections.abc import Sequence +from types import TracebackType +from typing import TYPE_CHECKING, Any, Optional + +from dbt.adapters.contracts.connection import AdapterResponse +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.databricks.handle import SqlUtils +from dbt.adapters.databricks.logging import logger + +if TYPE_CHECKING: + from pyspark.sql import DataFrame, Row, SparkSession + from pyspark.sql.types import StructField + + +class SessionCursorWrapper: + """ + Wraps SparkSession DataFrame results to provide a cursor-like interface. + + This adapter allows dbt to use SparkSession.sql() results in the same way + it uses DBSQL cursor results, maintaining compatibility with the existing + connection management code. + """ + + def __init__(self, spark: "SparkSession"): + self._spark = spark + self._df: Optional["DataFrame"] = None + self._rows: Optional[list["Row"]] = None + self._query_id: str = "session-query" + self.open = True + + def execute( + self, sql: str, bindings: Optional[Sequence[Any]] = None + ) -> "SessionCursorWrapper": + """Execute a SQL statement and store the resulting DataFrame.""" + cleaned_sql = SqlUtils.clean_sql(sql) + + # Handle bindings by simple string substitution if provided + if bindings: + translated = SqlUtils.translate_bindings(bindings) + if translated: + cleaned_sql = cleaned_sql % tuple(translated) + + logger.debug(f"Session mode executing SQL: {cleaned_sql[:200]}...") + self._df = self._spark.sql(cleaned_sql) + self._rows = None # Reset cached rows + return self + + def fetchall(self) -> Sequence[tuple]: + """Fetch all rows from the result set.""" + if self._rows is None and self._df is not None: + self._rows = self._df.collect() + return [tuple(row) for row in (self._rows or [])] + + def fetchone(self) -> Optional[tuple]: + """Fetch the next row from the result set.""" + if self._rows is None and self._df is not None: + self._rows = self._df.collect() + if self._rows: + return tuple(self._rows.pop(0)) + return None + + def fetchmany(self, size: int) -> Sequence[tuple]: + """Fetch the next `size` rows from the result set.""" + if self._rows is None and self._df is not None: + self._rows = self._df.collect() + if not self._rows: + return [] + result = [tuple(row) for row in self._rows[:size]] + self._rows = self._rows[size:] + return result + + @property + def description(self) -> Optional[list[tuple]]: + """Return column descriptions in DB-API format.""" + if self._df is None: + return None + return [self._field_to_description(f) for f in self._df.schema.fields] + + @staticmethod + def _field_to_description(field: "StructField") -> tuple: + """Convert a StructField to a DB-API description tuple.""" + # DB-API description: (name, type_code, display_size, internal_size, + # precision, scale, null_ok) + return ( + field.name, + field.dataType.simpleString(), + None, + None, + None, + None, + field.nullable, + ) + + def get_response(self) -> AdapterResponse: + """Return an adapter response for the executed query.""" + return AdapterResponse(_message="OK", query_id=self._query_id) + + def cancel(self) -> None: + """Cancel the current operation (no-op for session mode).""" + logger.debug("SessionCursorWrapper.cancel() called (no-op)") + self.open = False + + def close(self) -> None: + """Close the cursor.""" + 
logger.debug("SessionCursorWrapper.close() called") + self.open = False + self._df = None + self._rows = None + + def __str__(self) -> str: + return f"SessionCursor(query-id={self._query_id})" + + def __enter__(self) -> "SessionCursorWrapper": + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + self.close() + return exc_val is None + + +class DatabricksSessionHandle: + """ + Handle for a Databricks SparkSession. + + Provides the same interface as DatabricksHandle but uses the active SparkSession + instead of the DBSQL connector. This enables dbt to run on job clusters without + requiring external API connections. + """ + + def __init__(self, spark: "SparkSession"): + self._spark = spark + self.open = True + self._cursor: Optional[SessionCursorWrapper] = None + self._dbr_version: Optional[tuple[int, int]] = None + + @staticmethod + def create( + catalog: Optional[str] = None, + schema: Optional[str] = None, + session_properties: Optional[dict[str, Any]] = None, + ) -> "DatabricksSessionHandle": + """ + Create a DatabricksSessionHandle using the active SparkSession. + + Args: + catalog: Optional catalog to set as current + schema: Optional schema to set as current database + session_properties: Optional session configuration properties + + Returns: + A new DatabricksSessionHandle instance + """ + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + + # Set catalog if provided + if catalog: + try: + spark.catalog.setCurrentCatalog(catalog) + logger.debug(f"Set current catalog to: {catalog}") + except Exception as e: + logger.warning(f"Failed to set catalog '{catalog}': {e}") + # Fall back to USE CATALOG for older Spark versions + spark.sql(f"USE CATALOG {catalog}") + + # Set schema/database if provided + if schema: + spark.catalog.setCurrentDatabase(schema) + logger.debug(f"Set current database to: {schema}") + + # Apply session properties + if session_properties: + for key, value in session_properties.items(): + spark.conf.set(key, str(value)) + logger.debug(f"Set session property {key}={value}") + + handle = DatabricksSessionHandle(spark) + logger.debug(f"Created session handle: {handle}") + return handle + + @property + def dbr_version(self) -> tuple[int, int]: + """Get the DBR version of the current cluster.""" + if self._dbr_version is None: + try: + version_str = self._spark.conf.get( + "spark.databricks.clusterUsageTags.sparkVersion", "" + ) + if version_str: + self._dbr_version = SqlUtils.extract_dbr_version(version_str) + else: + # If we can't get the version, assume latest + logger.warning("Could not determine DBR version, assuming latest") + self._dbr_version = (sys.maxsize, sys.maxsize) + except Exception as e: + logger.warning(f"Failed to get DBR version: {e}, assuming latest") + self._dbr_version = (sys.maxsize, sys.maxsize) + return self._dbr_version + + @property + def session_id(self) -> str: + """Get a unique identifier for this session.""" + try: + # Try to get the Spark application ID as a session identifier + return self._spark.sparkContext.applicationId or "session-unknown" + except Exception: + return "session-unknown" + + def execute( + self, sql: str, bindings: Optional[Sequence[Any]] = None + ) -> SessionCursorWrapper: + """Execute a SQL statement and return a cursor wrapper.""" + if not self.open: + raise DbtRuntimeError("Attempting to execute on a closed session handle") + + if self._cursor: + self._cursor.close() + + 
self._cursor = SessionCursorWrapper(self._spark) + return self._cursor.execute(sql, bindings) + + def list_schemas( + self, database: str, schema: Optional[str] = None + ) -> SessionCursorWrapper: + """List schemas in the given database/catalog.""" + if schema: + sql = f"SHOW SCHEMAS IN {database} LIKE '{schema}'" + else: + sql = f"SHOW SCHEMAS IN {database}" + return self.execute(sql) + + def list_tables(self, database: str, schema: str) -> SessionCursorWrapper: + """List tables in the given database and schema.""" + sql = f"SHOW TABLES IN {database}.{schema}" + return self.execute(sql) + + def cancel(self) -> None: + """Cancel any in-progress operations.""" + logger.debug("DatabricksSessionHandle.cancel() called") + if self._cursor: + self._cursor.cancel() + self.open = False + + def close(self) -> None: + """Close the session handle.""" + logger.debug("DatabricksSessionHandle.close() called") + if self._cursor: + self._cursor.close() + self.open = False + # Note: We don't stop the SparkSession as it may be shared + + def rollback(self) -> None: + """Required for interface compatibility, but not implemented.""" + logger.debug("NotImplemented: rollback (session mode)") + + def __del__(self) -> None: + if self._cursor: + self._cursor.close() + self.close() + + def __str__(self) -> str: + return f"SessionHandle(session-id={self.session_id})" diff --git a/tests/unit/test_connection_manager.py b/tests/unit/test_connection_manager.py index 8cf9219e0..1a31b1cfb 100644 --- a/tests/unit/test_connection_manager.py +++ b/tests/unit/test_connection_manager.py @@ -69,6 +69,7 @@ def test_open_calls_is_cluster_http_path_for_warehouse( mock_connection.credentials.connect_retries = 1 mock_connection.credentials.connect_timeout = 10 mock_connection.credentials.query_tags = None + mock_connection.credentials.is_session_mode = False # Not session mode mock_connection.http_path = "sql/protocolv1/o/abc123def456" mock_connection.credentials.authenticate.return_value = Mock() mock_connection._query_header_context = None diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 000000000..f1d947827 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,323 @@ +"""Unit tests for session mode components.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from dbt.adapters.databricks.session import ( + DatabricksSessionHandle, + SessionCursorWrapper, +) + + +class TestSessionCursorWrapper: + """Tests for SessionCursorWrapper.""" + + @pytest.fixture + def mock_spark(self): + """Create a mock SparkSession.""" + spark = MagicMock() + return spark + + @pytest.fixture + def cursor(self, mock_spark): + """Create a SessionCursorWrapper with mock SparkSession.""" + return SessionCursorWrapper(mock_spark) + + def test_execute_cleans_sql(self, cursor, mock_spark): + """Test that execute cleans SQL (strips whitespace and trailing semicolon).""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + result = cursor.execute(" SELECT 1; ") + + mock_spark.sql.assert_called_once_with("SELECT 1") + assert result is cursor + + def test_execute_with_bindings(self, cursor, mock_spark): + """Test that execute handles bindings via string substitution.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT %s, %s", (1, "test")) + + mock_spark.sql.assert_called_once_with("SELECT 1, test") + + def test_fetchall_returns_tuples(self, cursor, mock_spark): + """Test that fetchall returns list of tuples.""" + mock_df = MagicMock() + 
mock_row1 = MagicMock() + mock_row1.__iter__ = lambda self: iter([1, "a"]) + mock_row2 = MagicMock() + mock_row2.__iter__ = lambda self: iter([2, "b"]) + mock_df.collect.return_value = [mock_row1, mock_row2] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchall() + + assert result == [(1, "a"), (2, "b")] + + def test_fetchone_returns_single_tuple(self, cursor, mock_spark): + """Test that fetchone returns a single tuple.""" + mock_df = MagicMock() + mock_row = MagicMock() + mock_row.__iter__ = lambda self: iter([1, "a"]) + mock_df.collect.return_value = [mock_row] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchone() + + assert result == (1, "a") + + def test_fetchone_returns_none_when_empty(self, cursor, mock_spark): + """Test that fetchone returns None when no rows.""" + mock_df = MagicMock() + mock_df.collect.return_value = [] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchone() + + assert result is None + + def test_fetchmany_returns_limited_rows(self, cursor, mock_spark): + """Test that fetchmany returns limited number of rows.""" + mock_df = MagicMock() + mock_rows = [MagicMock() for _ in range(5)] + for i, row in enumerate(mock_rows): + row.__iter__ = (lambda i: lambda self: iter([i]))(i) + mock_df.collect.return_value = mock_rows + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + result = cursor.fetchmany(2) + + assert len(result) == 2 + assert result == [(0,), (1,)] + + def test_description_returns_column_info(self, cursor, mock_spark): + """Test that description returns column metadata.""" + mock_df = MagicMock() + mock_field1 = MagicMock() + mock_field1.name = "id" + mock_field1.dataType.simpleString.return_value = "int" + mock_field1.nullable = False + + mock_field2 = MagicMock() + mock_field2.name = "name" + mock_field2.dataType.simpleString.return_value = "string" + mock_field2.nullable = True + + mock_df.schema.fields = [mock_field1, mock_field2] + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT * FROM test") + desc = cursor.description + + assert len(desc) == 2 + assert desc[0][0] == "id" + assert desc[0][1] == "int" + assert desc[0][6] is False + assert desc[1][0] == "name" + assert desc[1][1] == "string" + assert desc[1][6] is True + + def test_description_returns_none_before_execute(self, cursor): + """Test that description returns None before execute.""" + assert cursor.description is None + + def test_get_response_returns_adapter_response(self, cursor, mock_spark): + """Test that get_response returns an AdapterResponse.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor.execute("SELECT 1") + response = cursor.get_response() + + assert response._message == "OK" + assert response.query_id == "session-query" + + def test_close_sets_open_to_false(self, cursor): + """Test that close sets open to False.""" + assert cursor.open is True + cursor.close() + assert cursor.open is False + + def test_context_manager(self, mock_spark): + """Test that cursor works as context manager.""" + with SessionCursorWrapper(mock_spark) as cursor: + assert cursor.open is True + assert cursor.open is False + + +class TestDatabricksSessionHandle: + """Tests for DatabricksSessionHandle.""" + + @pytest.fixture + def mock_spark(self): + """Create a mock SparkSession.""" + spark = MagicMock() + spark.sparkContext.applicationId = "app-123" + 
spark.conf.get.return_value = "14.3.x-scala2.12" + return spark + + @pytest.fixture + def handle(self, mock_spark): + """Create a DatabricksSessionHandle with mock SparkSession.""" + return DatabricksSessionHandle(mock_spark) + + def test_session_id_returns_application_id(self, handle): + """Test that session_id returns the Spark application ID.""" + assert handle.session_id == "app-123" + + def test_dbr_version_extracts_version(self, handle, mock_spark): + """Test that dbr_version extracts version from Spark config.""" + mock_spark.conf.get.return_value = "14.3.x-scala2.12" + + version = handle.dbr_version + + assert version == (14, 3) + + def test_dbr_version_caches_result(self, handle, mock_spark): + """Test that dbr_version caches the result.""" + mock_spark.conf.get.return_value = "14.3.x-scala2.12" + + _ = handle.dbr_version + _ = handle.dbr_version + + # Should only call conf.get once due to caching + assert mock_spark.conf.get.call_count == 1 + + def test_execute_returns_cursor(self, handle, mock_spark): + """Test that execute returns a SessionCursorWrapper.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor = handle.execute("SELECT 1") + + assert isinstance(cursor, SessionCursorWrapper) + + def test_execute_closes_previous_cursor(self, handle, mock_spark): + """Test that execute closes any previous cursor.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + cursor1 = handle.execute("SELECT 1") + assert cursor1.open is True + + cursor2 = handle.execute("SELECT 2") + assert cursor1.open is False + assert cursor2.open is True + + def test_close_sets_open_to_false(self, handle): + """Test that close sets open to False.""" + assert handle.open is True + handle.close() + assert handle.open is False + + def test_list_schemas_executes_show_schemas(self, handle, mock_spark): + """Test that list_schemas executes SHOW SCHEMAS.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_schemas("my_catalog") + + mock_spark.sql.assert_called_with("SHOW SCHEMAS IN my_catalog") + + def test_list_schemas_with_pattern(self, handle, mock_spark): + """Test that list_schemas includes LIKE pattern.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_schemas("my_catalog", "my_schema") + + mock_spark.sql.assert_called_with("SHOW SCHEMAS IN my_catalog LIKE 'my_schema'") + + def test_list_tables_executes_show_tables(self, handle, mock_spark): + """Test that list_tables executes SHOW TABLES.""" + mock_df = MagicMock() + mock_spark.sql.return_value = mock_df + + handle.list_tables("my_catalog", "my_schema") + + mock_spark.sql.assert_called_with("SHOW TABLES IN my_catalog.my_schema") + + def test_create_gets_or_creates_spark_session(self): + """Test that create uses getOrCreate.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + handle = DatabricksSessionHandle.create() + + mock_builder.getOrCreate.assert_called_once() + assert handle.session_id == "app-456" + + def test_create_sets_catalog(self): + """Test that create sets the catalog.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + 
"sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + DatabricksSessionHandle.create(catalog="my_catalog") + + mock_spark.catalog.setCurrentCatalog.assert_called_once_with("my_catalog") + + def test_create_sets_schema(self): + """Test that create sets the schema.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + DatabricksSessionHandle.create(schema="my_schema") + + mock_spark.catalog.setCurrentDatabase.assert_called_once_with("my_schema") + + def test_create_sets_session_properties(self): + """Test that create sets session properties.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-456" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + DatabricksSessionHandle.create(session_properties={"key1": "value1", "key2": 123}) + + mock_spark.conf.set.assert_any_call("key1", "value1") + mock_spark.conf.set.assert_any_call("key2", "123") diff --git a/tests/unit/test_session_credentials.py b/tests/unit/test_session_credentials.py new file mode 100644 index 000000000..3cc4432a2 --- /dev/null +++ b/tests/unit/test_session_credentials.py @@ -0,0 +1,236 @@ +"""Unit tests for session mode credentials.""" + +import os +from unittest.mock import patch + +import pytest +from dbt_common.exceptions import DbtConfigError, DbtRuntimeError, DbtValidationError + +from dbt.adapters.databricks.credentials import ( + CONNECTION_METHOD_DBSQL, + CONNECTION_METHOD_SESSION, + DBT_DATABRICKS_SESSION_MODE_ENV, + DATABRICKS_RUNTIME_VERSION_ENV, + DatabricksCredentials, +) + + +class TestSessionModeAutoDetection: + """Tests for session mode auto-detection.""" + + def test_default_method_is_dbsql(self): + """Test that default method is dbsql when no env vars set.""" + with patch.dict(os.environ, {}, clear=True): + # Need to provide host/http_path for dbsql mode + creds = DatabricksCredentials( + host="my.databricks.com", + http_path="/sql/1.0/warehouses/abc", + token="token", + schema="test_schema", + ) + assert creds.method == CONNECTION_METHOD_DBSQL + assert creds.is_session_mode is False + + def test_explicit_session_method(self): + """Test that explicit method=session is respected.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + assert creds.method == CONNECTION_METHOD_SESSION + assert creds.is_session_mode is True + + def test_explicit_dbsql_method(self): + """Test that explicit method=dbsql is respected.""" + creds = DatabricksCredentials( + method="dbsql", + host="my.databricks.com", + http_path="/sql/1.0/warehouses/abc", + token="token", + schema="test_schema", + ) + assert creds.method == CONNECTION_METHOD_DBSQL + assert creds.is_session_mode is False + + def test_env_var_enables_session_mode(self): + """Test that DBT_DATABRICKS_SESSION_MODE=true enables session mode.""" + with patch.dict(os.environ, {DBT_DATABRICKS_SESSION_MODE_ENV: "true"}): + creds = DatabricksCredentials(schema="test_schema") + assert creds.method == CONNECTION_METHOD_SESSION 
+ assert creds.is_session_mode is True + + def test_env_var_case_insensitive(self): + """Test that DBT_DATABRICKS_SESSION_MODE is case insensitive.""" + with patch.dict(os.environ, {DBT_DATABRICKS_SESSION_MODE_ENV: "TRUE"}): + creds = DatabricksCredentials(schema="test_schema") + assert creds.is_session_mode is True + + def test_databricks_runtime_env_without_host_enables_session(self): + """Test that DATABRICKS_RUNTIME_VERSION without host enables session mode.""" + with patch.dict(os.environ, {DATABRICKS_RUNTIME_VERSION_ENV: "14.3.x-scala2.12"}): + creds = DatabricksCredentials(schema="test_schema") + assert creds.method == CONNECTION_METHOD_SESSION + assert creds.is_session_mode is True + + def test_databricks_runtime_env_with_host_uses_dbsql(self): + """Test that DATABRICKS_RUNTIME_VERSION with host uses dbsql mode.""" + with patch.dict(os.environ, {DATABRICKS_RUNTIME_VERSION_ENV: "14.3.x-scala2.12"}): + creds = DatabricksCredentials( + host="my.databricks.com", + http_path="/sql/1.0/warehouses/abc", + token="token", + schema="test_schema", + ) + assert creds.method == CONNECTION_METHOD_DBSQL + assert creds.is_session_mode is False + + def test_invalid_method_raises_error(self): + """Test that invalid method raises DbtValidationError.""" + with pytest.raises(DbtValidationError) as exc_info: + DatabricksCredentials( + method="invalid", + schema="test_schema", + ) + assert "Invalid connection method" in str(exc_info.value) + + +class TestSessionModeValidation: + """Tests for session mode validation.""" + + def test_session_mode_requires_schema(self): + """Test that session mode requires schema.""" + with patch( + "dbt.adapters.databricks.credentials.DatabricksCredentials._validate_session_mode" + ) as mock_validate: + mock_validate.side_effect = DbtValidationError("Schema is required for session mode.") + with pytest.raises(DbtValidationError) as exc_info: + creds = DatabricksCredentials(method="session") + creds.validate_creds() + assert "Schema is required" in str(exc_info.value) + + def test_session_mode_does_not_require_host(self): + """Test that session mode does not require host.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + # Should not raise - host is not required for session mode + assert creds.host is None + + def test_session_mode_does_not_require_http_path(self): + """Test that session mode does not require http_path.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + # Should not raise - http_path is not required for session mode + assert creds.http_path is None + + def test_session_mode_does_not_require_token(self): + """Test that session mode does not require token.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + # Should not raise - token is not required for session mode + assert creds.token is None + + +class TestSessionModeCredentialsManager: + """Tests for credentials manager in session mode.""" + + def test_session_mode_does_not_create_credentials_manager(self): + """Test that session mode does not create credentials manager.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + assert creds._credentials_manager is None + + def test_session_mode_authenticate_returns_none(self): + """Test that authenticate returns None in session mode.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + ) + # Mock _validate_session_mode to avoid pyspark import + with patch.object(creds, "_validate_session_mode"): 
+ result = creds.authenticate() + assert result is None + + +class TestSessionModeConnectionKeys: + """Tests for connection keys in session mode.""" + + def test_session_mode_connection_keys(self): + """Test that session mode has correct connection keys.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + database="main", + ) + keys = creds._connection_keys() + assert "method" in keys + assert "schema" in keys + assert "catalog" in keys + assert "host" not in keys + assert "http_path" not in keys + + def test_session_mode_unique_field(self): + """Test that session mode unique_field is based on catalog/schema.""" + creds = DatabricksCredentials( + method="session", + schema="test_schema", + database="main", + ) + assert creds.unique_field == "session://main/test_schema" + + +class TestDbsqlModeValidation: + """Tests for DBSQL mode validation (existing behavior).""" + + def test_dbsql_mode_requires_host(self): + """Test that DBSQL mode requires host.""" + # Create credentials with session mode first (no SDK auth), then switch to dbsql + with patch.dict(os.environ, {}, clear=True): + creds = DatabricksCredentials( + method="session", # Start with session to avoid SDK auth + http_path="/sql/1.0/warehouses/abc", + schema="test_schema", + ) + # Switch to dbsql mode for validation test + creds.method = "dbsql" + with pytest.raises(DbtConfigError) as exc_info: + creds.validate_creds() + assert "host" in str(exc_info.value) + + def test_dbsql_mode_requires_http_path(self): + """Test that DBSQL mode requires http_path.""" + # Create credentials with session mode first (no SDK auth), then switch to dbsql + with patch.dict(os.environ, {}, clear=True): + creds = DatabricksCredentials( + method="session", # Start with session to avoid SDK auth + host="my.databricks.com", + schema="test_schema", + ) + # Switch to dbsql mode for validation test + creds.method = "dbsql" + with pytest.raises(DbtConfigError) as exc_info: + creds.validate_creds() + assert "http_path" in str(exc_info.value) + + def test_dbsql_mode_requires_token_or_oauth(self): + """Test that DBSQL mode requires token or oauth.""" + # Create credentials with session mode first (no SDK auth), then switch to dbsql + with patch.dict(os.environ, {}, clear=True): + creds = DatabricksCredentials( + method="session", # Start with session to avoid SDK auth + host="my.databricks.com", + http_path="/sql/1.0/warehouses/abc", + schema="test_schema", + ) + # Switch to dbsql mode for validation test + creds.method = "dbsql" + with pytest.raises(DbtConfigError) as exc_info: + creds.validate_creds() + assert "oauth" in str(exc_info.value).lower() or "token" in str(exc_info.value).lower() diff --git a/tests/unit/test_session_python.py b/tests/unit/test_session_python.py new file mode 100644 index 000000000..0544ba449 --- /dev/null +++ b/tests/unit/test_session_python.py @@ -0,0 +1,189 @@ +"""Unit tests for session mode Python model submission.""" + +from unittest.mock import MagicMock, patch + +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.databricks.python_models.python_submissions import ( + SessionPythonJobHelper, + SessionPythonSubmitter, + SessionStateManager, +) + + +class TestSessionStateManager: + """Tests for SessionStateManager.""" + + def test_get_clean_exec_globals_includes_spark(self): + """Test that get_clean_exec_globals includes spark.""" + mock_spark = MagicMock() + + globals_dict = SessionStateManager.get_clean_exec_globals(mock_spark) + + assert globals_dict["spark"] is 
mock_spark + + def test_get_clean_exec_globals_includes_dbt(self): + """Test that get_clean_exec_globals includes dbt module.""" + mock_spark = MagicMock() + + globals_dict = SessionStateManager.get_clean_exec_globals(mock_spark) + + assert "dbt" in globals_dict + + def test_cleanup_temp_views_drops_temp_views(self): + """Test that cleanup_temp_views drops temporary views.""" + mock_spark = MagicMock() + mock_row = MagicMock() + mock_row.viewName = "temp_view_1" + mock_row.isTemporary = True + mock_spark.sql.return_value.collect.return_value = [mock_row] + + SessionStateManager.cleanup_temp_views(mock_spark) + + mock_spark.catalog.dropTempView.assert_called_once_with("temp_view_1") + + def test_cleanup_temp_views_ignores_non_temp_views(self): + """Test that cleanup_temp_views ignores non-temporary views.""" + mock_spark = MagicMock() + mock_row = MagicMock() + mock_row.viewName = "permanent_view" + mock_row.isTemporary = False + mock_spark.sql.return_value.collect.return_value = [mock_row] + + SessionStateManager.cleanup_temp_views(mock_spark) + + mock_spark.catalog.dropTempView.assert_not_called() + + +class TestSessionPythonSubmitter: + """Tests for SessionPythonSubmitter.""" + + @pytest.fixture + def mock_spark(self): + """Create a mock SparkSession.""" + return MagicMock() + + @pytest.fixture + def submitter(self, mock_spark): + """Create a SessionPythonSubmitter with mock SparkSession.""" + return SessionPythonSubmitter(mock_spark) + + def test_submit_executes_code(self, submitter, mock_spark): + """Test that submit executes the compiled code.""" + # Simple code that just sets a variable + compiled_code = "result = 1 + 1" + + submitter.submit(compiled_code) + + # The code should execute without error + + def test_submit_provides_spark_in_globals(self, submitter, mock_spark): + """Test that submit provides spark in the execution globals.""" + # Code that uses spark + compiled_code = "spark_app_name = spark.sparkContext.appName" + + mock_spark.sparkContext.appName = "test-app" + + submitter.submit(compiled_code) + + # The code should execute without error + + def test_submit_raises_on_execution_error(self, submitter, mock_spark): + """Test that submit raises DbtRuntimeError on execution error.""" + compiled_code = "raise ValueError('test error')" + + with pytest.raises(DbtRuntimeError) as exc_info: + submitter.submit(compiled_code) + + assert "Python model execution failed" in str(exc_info.value) + assert "test error" in str(exc_info.value) + + def test_submit_cleans_up_temp_views(self, submitter, mock_spark): + """Test that submit cleans up temp views after execution.""" + compiled_code = "result = 1" + mock_spark.sql.return_value.collect.return_value = [] + + submitter.submit(compiled_code) + + # Cleanup should be called (SHOW VIEWS) + mock_spark.sql.assert_called() + + +class TestSessionPythonJobHelper: + """Tests for SessionPythonJobHelper.""" + + @pytest.fixture + def mock_credentials(self): + """Create mock credentials.""" + creds = MagicMock() + creds.is_session_mode = True + return creds + + @pytest.fixture + def parsed_model_dict(self): + """Create a parsed model dictionary.""" + return { + "catalog": "main", + "schema": "test_schema", + "identifier": "test_model", + "config": { + "timeout": 3600, + "packages": [], + "index_url": None, + "additional_libs": [], + "python_job_config": {}, # Empty dict instead of None + "cluster_id": None, + "http_path": None, + "create_notebook": False, + "job_cluster_config": {}, # Empty dict instead of None + "access_control_list": [], + 
"notebook_access_control_list": [], + "user_folder_for_python": True, + "environment_key": None, + "environment_dependencies": [], + }, + } + + def test_init_gets_spark_session(self, mock_credentials, parsed_model_dict): + """Test that __init__ gets the SparkSession.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-123" + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + helper = SessionPythonJobHelper(parsed_model_dict, mock_credentials) + + mock_builder.getOrCreate.assert_called_once() + assert helper._spark is mock_spark + + def test_submit_delegates_to_submitter(self, mock_credentials, parsed_model_dict): + """Test that submit delegates to the submitter.""" + mock_spark = MagicMock() + mock_spark.sparkContext.applicationId = "app-123" + mock_spark.sql.return_value.collect.return_value = [] + mock_builder = MagicMock() + mock_builder.getOrCreate.return_value = mock_spark + + with patch.dict( + "sys.modules", + {"pyspark": MagicMock(), "pyspark.sql": MagicMock()}, + ): + import sys + + sys.modules["pyspark.sql"].SparkSession.builder = mock_builder + + helper = SessionPythonJobHelper(parsed_model_dict, mock_credentials) + + compiled_code = "result = 1" + helper.submit(compiled_code) + + # The code should execute without error