Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
556 changes: 556 additions & 0 deletions PLAN_v2_session_mode.md

Large diffs are not rendered by default.

100 changes: 88 additions & 12 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from dbt.adapters.databricks.handle import CursorWrapper, DatabricksHandle, SqlUtils
from dbt.adapters.databricks.logging import logger
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
from dbt.adapters.databricks.session import DatabricksSessionHandle, SessionCursorWrapper
from dbt.adapters.databricks.utils import QueryTagsUtils, is_cluster_http_path, redact_credentials

if TYPE_CHECKING:
Expand Down Expand Up @@ -150,6 +151,8 @@ class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"
credentials_manager: Optional[DatabricksCredentialManager] = None
_dbr_capabilities_cache: dict[str, DBRCapabilities] = {}
# Cache for session mode capabilities (keyed by "session")
_session_capabilities: Optional[DBRCapabilities] = None

def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
super().__init__(profile, mp_context)
Expand All @@ -159,9 +162,19 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
def api_client(self) -> DatabricksApiClient:
    """Lazily construct and memoize the Databricks REST API client.

    Raises:
        DbtRuntimeError: if the profile is configured for session mode,
            where API-based operations are unsupported.
    """
    if self._api_client is not None:
        return self._api_client
    creds = cast(DatabricksCredentials, self.profile.credentials)
    if creds.is_session_mode:
        raise DbtRuntimeError(
            "API client is not available in session mode. "
            "Session mode does not support API-based operations."
        )
    # 15-minute timeout for API operations.
    self._api_client = DatabricksApiClient(creds, 15 * 60)
    return self._api_client

def is_session_mode(self) -> bool:
    """Return True when the profile's credentials select session (SparkSession) mode."""
    return cast(DatabricksCredentials, self.profile.credentials).is_session_mode

def is_cluster(self) -> bool:
conn = self.get_thread_connection()
databricks_conn = cast(DatabricksDBTConnection, conn)
Expand Down Expand Up @@ -210,14 +223,16 @@ def _cache_dbr_capabilities(cls, creds: DatabricksCredentials, http_path: str) -

def cancel_open(self) -> list[str]:
    """Cancel open connections; additionally cancel tracked Python jobs via the API.

    In session mode there is no API client, so only the base-class
    cancellation runs.
    """
    cancelled = super().cancel_open()
    if self.is_session_mode():
        # Session mode has no API client; nothing further to cancel.
        return cancelled
    logger.info("Cancelling open python jobs")
    PythonRunTracker.cancel_runs(self.api_client)
    return cancelled

def compare_dbr_version(self, major: int, minor: int) -> int:
    """Three-way compare the active DBR version against ``(major, minor)``.

    Returns:
        1 if the runtime is newer, -1 if older, 0 if equal.
    """
    target = (major, minor)
    current = self.get_thread_connection().handle.dbr_version
    if current == target:
        return 0
    return 1 if current > target else -1

Expand Down Expand Up @@ -321,7 +336,7 @@ def add_query(
fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name)))

with self.exception_handler(sql):
cursor: Optional[CursorWrapper] = None
cursor: Optional[CursorWrapper | SessionCursorWrapper] = None
try:
log_sql = redact_credentials(sql)
if abridge_sql_log:
Expand All @@ -337,7 +352,7 @@ def add_query(

pre = time.time()

handle: DatabricksHandle = connection.handle
handle: DatabricksHandle | DatabricksSessionHandle = connection.handle
cursor = handle.execute(sql, bindings)
response = self.get_response(cursor)
fire_event(
Expand Down Expand Up @@ -380,14 +395,18 @@ def execute(
cursor.close()

def _execute_with_cursor(
self, log_sql: str, f: Callable[[DatabricksHandle], CursorWrapper]
self,
log_sql: str,
f: Callable[
[DatabricksHandle | DatabricksSessionHandle], CursorWrapper | SessionCursorWrapper
],
) -> "Table":
connection = self.get_thread_connection()

fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=cast_to_str(connection.name)))

with self.exception_handler(log_sql):
cursor: Optional[CursorWrapper] = None
cursor: Optional[CursorWrapper | SessionCursorWrapper] = None
try:
fire_event(
SQLQuery(
Expand All @@ -399,7 +418,7 @@ def _execute_with_cursor(

pre = time.time()

handle: DatabricksHandle = connection.handle
handle: DatabricksHandle | DatabricksSessionHandle = connection.handle
cursor = f(handle)

response = self.get_response(cursor)
Expand Down Expand Up @@ -464,9 +483,66 @@ def open(cls, connection: Connection) -> Connection:
return connection

creds: DatabricksCredentials = connection.credentials

# Dispatch based on connection method
if creds.is_session_mode:
return cls._open_session(databricks_connection, creds)
else:
return cls._open_dbsql(databricks_connection, creds)

@classmethod
def _open_session(
    cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials
) -> Connection:
    """Open a connection using SparkSession mode.

    Creates a DatabricksSessionHandle bound to the profile's catalog/schema
    and session properties, records the session id and cached capabilities
    on the dbt connection object, and marks the connection OPEN.

    Raises:
        DbtDatabaseError: if the session handle cannot be created.
    """
    logger.debug("Opening connection in session mode")

    def connect() -> DatabricksSessionHandle:
        try:
            handle = DatabricksSessionHandle.create(
                catalog=creds.database,
                schema=creds.schema,
                session_properties=creds.session_properties,
            )
            databricks_connection.session_id = handle.session_id

            # Cache capabilities for session mode (class-level, computed once
            # per process; falls back to default DBRCapabilities if unset)
            cls._cache_session_capabilities(handle)
            databricks_connection.capabilities = cls._session_capabilities or DBRCapabilities()

            logger.debug(f"Session mode connection opened: {handle}")
            return handle
        except Exception as exc:
            # Log the structured event, then surface a dbt database error
            # with the original exception chained as the cause.
            logger.error(ConnectionCreateError(exc))
            raise DbtDatabaseError(f"Failed to create session connection: {exc}") from exc

    # Session mode doesn't need retry logic as SparkSession is already available
    databricks_connection.handle = connect()
    databricks_connection.state = ConnectionState.OPEN
    return databricks_connection

@classmethod
def _cache_session_capabilities(cls, handle: DatabricksSessionHandle) -> None:
    """Populate the class-level session capability cache on first use.

    Subsequent calls are no-ops once the cache is filled.
    """
    if cls._session_capabilities is not None:
        return
    dbr_version = handle.dbr_version
    cls._session_capabilities = DBRCapabilities(
        dbr_version=dbr_version,
        # Session mode always runs on a cluster, never a SQL warehouse.
        is_sql_warehouse=False,
    )
    logger.debug(f"Cached session capabilities: DBR version {dbr_version}")

@classmethod
def _open_dbsql(
cls, databricks_connection: DatabricksDBTConnection, creds: DatabricksCredentials
) -> Connection:
"""Open a connection using DBSQL connector."""
timeout = creds.connect_timeout

cls.credentials_manager = creds.authenticate()
credentials_manager = creds.authenticate()
# In DBSQL mode, authenticate() always returns a credentials manager
assert credentials_manager is not None, "Credentials manager is required for DBSQL mode"
cls.credentials_manager = credentials_manager

# Get merged query tags if we have query header context
query_header_context = getattr(databricks_connection, "_query_header_context", None)
Expand All @@ -475,7 +551,7 @@ def open(cls, connection: Connection) -> Connection:
merged_query_tags = QueryConfigUtils.get_merged_query_tags(query_header_context, creds)

conn_args = SqlUtils.prepare_connection_arguments(
creds, cls.credentials_manager, databricks_connection.http_path, merged_query_tags
creds, credentials_manager, databricks_connection.http_path, merged_query_tags
)

def connect() -> DatabricksHandle:
Expand Down Expand Up @@ -507,7 +583,7 @@ def exponential_backoff(attempt: int) -> int:
retryable_exceptions = [Error]

return cls.retry_connection(
connection,
databricks_connection,
connect=connect,
logger=logger,
retryable_exceptions=retryable_exceptions,
Expand All @@ -527,7 +603,7 @@ def close(cls, connection: Connection) -> Connection:

@classmethod
def get_response(cls, cursor: Any) -> AdapterResponse:
    """Build an AdapterResponse from a cursor wrapper, defaulting to "OK"."""
    is_known_wrapper = isinstance(cursor, (CursorWrapper, SessionCursorWrapper))
    return cursor.get_response() if is_known_wrapper else AdapterResponse("OK")
Expand Down
93 changes: 90 additions & 3 deletions dbt/adapters/databricks/credentials.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import itertools
import json
import os
import re
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, cast

from dbt.adapters.contracts.connection import Credentials
from dbt_common.exceptions import DbtConfigError, DbtValidationError
from dbt_common.exceptions import DbtConfigError, DbtRuntimeError, DbtValidationError
from mashumaro import DataClassDictMixin
from requests import PreparedRequest
from requests.auth import AuthBase
Expand All @@ -16,6 +17,14 @@
from dbt.adapters.databricks.global_state import GlobalState
from dbt.adapters.databricks.logging import logger

# Connection method constants
CONNECTION_METHOD_SESSION = "session"
CONNECTION_METHOD_DBSQL = "dbsql"

# Environment variable for session mode
DBT_DATABRICKS_SESSION_MODE_ENV = "DBT_DATABRICKS_SESSION_MODE"
DATABRICKS_RUNTIME_VERSION_ENV = "DATABRICKS_RUNTIME_VERSION"

CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")
EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
Expand Down Expand Up @@ -48,6 +57,9 @@ class DatabricksCredentials(Credentials):
connection_parameters: Optional[dict[str, Any]] = None
auth_type: Optional[str] = None

# Connection method: "session" for SparkSession mode, "dbsql" for DBSQL connector (default)
method: Optional[str] = None

# Named compute resources specified in the profile. Used for
# creating a connection when a model specifies a compute resource.
compute: Optional[dict[str, Any]] = None
Expand Down Expand Up @@ -102,6 +114,9 @@ def __post_init__(self) -> None:
else:
self.database = "hive_metastore"

# Auto-detect and validate connection method
self._init_connection_method()

connection_parameters = self.connection_parameters or {}
for key in (
"server_hostname",
Expand Down Expand Up @@ -130,9 +145,57 @@ def __post_init__(self) -> None:
if "_socket_timeout" not in connection_parameters:
connection_parameters["_socket_timeout"] = 600
self.connection_parameters = connection_parameters
self._credentials_manager = DatabricksCredentialManager.create_from(self)

# Only create credentials manager for non-session mode
if not self.is_session_mode:
self._credentials_manager = DatabricksCredentialManager.create_from(self)

def _init_connection_method(self) -> None:
    """Resolve the connection method, auto-detecting session mode when unset.

    Raises:
        DbtValidationError: if an explicit method is neither "session" nor "dbsql".
    """
    if self.method is None:
        env_forces_session = (
            os.getenv(DBT_DATABRICKS_SESSION_MODE_ENV, "").lower() == "true"
        )
        # Running inside a Databricks runtime with no host configured implies
        # we can only talk to the local SparkSession.
        on_cluster_without_host = bool(
            os.getenv(DATABRICKS_RUNTIME_VERSION_ENV)
        ) and not self.host
        if env_forces_session or on_cluster_without_host:
            self.method = CONNECTION_METHOD_SESSION
        else:
            self.method = CONNECTION_METHOD_DBSQL

    valid_methods = (CONNECTION_METHOD_SESSION, CONNECTION_METHOD_DBSQL)
    if self.method not in valid_methods:
        raise DbtValidationError(
            f"Invalid connection method: '{self.method}'. "
            f"Must be '{CONNECTION_METHOD_SESSION}' or '{CONNECTION_METHOD_DBSQL}'."
        )

@property
def is_session_mode(self) -> bool:
    """True when ``method`` resolved to session (SparkSession) mode.

    ``method`` is normalized by ``_init_connection_method`` during
    ``__post_init__``, so this reflects the effective connection strategy.
    """
    return self.method == CONNECTION_METHOD_SESSION

def _validate_session_mode(self) -> None:
    """Validate configuration for session mode.

    Raises:
        DbtRuntimeError: if PySpark cannot be imported (i.e. not running
            on a Databricks cluster).
        DbtValidationError: if no schema is configured.
    """
    try:
        from pyspark.sql import SparkSession  # noqa: F401
    except ImportError as err:
        # Chain explicitly (B904) so the underlying ImportError is
        # preserved as the cause rather than an incidental context.
        raise DbtRuntimeError(
            "Session mode requires PySpark. "
            "Please ensure you are running on a Databricks cluster with PySpark available."
        ) from err

    if self.schema is None:
        raise DbtValidationError("Schema is required for session mode.")

def validate_creds(self) -> None:
    """Dispatch credential validation according to the connection method."""
    validator = (
        self._validate_session_mode if self.is_session_mode else self._validate_dbsql_creds
    )
    validator()

def _validate_dbsql_creds(self) -> None:
"""Validate credentials for DBSQL connector mode."""
for key in ["host", "http_path"]:
if not getattr(self, key):
raise DbtConfigError(f"The config '{key}' is required to connect to Databricks")
Expand Down Expand Up @@ -197,6 +260,9 @@ def type(self) -> str:

@property
def unique_field(self) -> str:
    """Stable identifier for this target: the host, or catalog/schema in session mode."""
    if not self.is_session_mode:
        return cast(str, self.host)
    # No host is configured in session mode; key on the catalog/schema pair.
    return f"session://{self.database}/{self.schema}"

def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str, Any]]:
Expand All @@ -209,6 +275,15 @@ def connection_info(self, *, with_aliases: bool = False) -> Iterable[tuple[str,
if key in as_dict:
yield key, as_dict[key]

def _connection_keys_session(self) -> tuple[str, ...]:
"""Connection keys for session mode."""
connection_keys = ["method", "schema"]
if self.database:
connection_keys.insert(1, "catalog")
if self.session_properties:
connection_keys.append("session_properties")
return tuple(connection_keys)

def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]:
# Assuming `DatabricksCredentials.connection_info(self, *, with_aliases: bool = False)`
# is called from only:
Expand All @@ -218,6 +293,11 @@ def _connection_keys(self, *, with_aliases: bool = False) -> tuple[str, ...]:
#
# Thus, if `with_aliases` is `True`, `DatabricksCredentials._connection_keys` should return
# the internal key names; otherwise it can use aliases to show in `dbt debug`.

# Session mode has different connection keys
if self.is_session_mode:
return self._connection_keys_session()

connection_keys = ["host", "http_path", "schema"]
if with_aliases:
connection_keys.insert(2, "database")
Expand All @@ -239,8 +319,15 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]:
def cluster_id(self) -> Optional[str]:
return self.extract_cluster_id(self.http_path) # type: ignore[arg-type]

def authenticate(self) -> Optional["DatabricksCredentialManager"]:
    """Validate credentials and return the credential manager.

    Session mode authenticates implicitly through the ambient SparkSession,
    so no external credential manager exists and None is returned. DBSQL
    mode returns the manager created during ``__post_init__``.
    """
    self.validate_creds()
    if not self.is_session_mode:
        manager = self._credentials_manager
        assert manager is not None, "Credentials manager is not set."
        return manager
    return None

Expand Down
Loading