diff --git a/tests/test_cursors.py b/tests/test_cursors.py index 51e9dd0..634f0d0 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -142,7 +142,11 @@ class TestCursor(BaseCursorTestSuit): def sync_cursor( self, session_pool_sync: ydb.QuerySessionPool ) -> Generator[Cursor]: - cursor = Cursor(session_pool_sync, ydb.QuerySerializableReadWrite()) + cursor = Cursor( + session_pool_sync, + ydb.QuerySerializableReadWrite(), + request_settings=ydb.BaseRequestSettings(), + ) yield cursor cursor.close() @@ -177,7 +181,11 @@ class TestAsyncCursor(BaseCursorTestSuit): async def async_cursor( self, session_pool: ydb.aio.QuerySessionPool ) -> AsyncGenerator[Cursor]: - cursor = AsyncCursor(session_pool, ydb.QuerySerializableReadWrite()) + cursor = AsyncCursor( + session_pool, + ydb.QuerySerializableReadWrite(), + request_settings=ydb.BaseRequestSettings(), + ) yield cursor await greenlet_spawn(cursor.close) diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index d7748b0..b99874d 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -18,6 +18,7 @@ from .errors import InternalError from .errors import NotSupportedError from .utils import handle_ydb_errors +from .utils import maybe_get_current_trace_id class IsolationLevel(str, Enum): @@ -101,6 +102,9 @@ def __init__( self._session_pool = self._pool_cls(self._driver, size=5) self._session: ydb.QuerySession | ydb.aio.QuerySession | None = None + self.request_settings: ydb.BaseRequestSettings = ( + ydb.BaseRequestSettings() + ) def set_isolation_level(self, isolation_level: IsolationLevel) -> None: if self._tx_context and self._tx_context.tx_id: @@ -129,6 +133,20 @@ def get_isolation_level(self) -> str: msg = f"{self._tx_mode.name} is not supported" raise NotSupportedError(msg) + def set_ydb_request_settings(self, value: ydb.BaseRequestSettings) -> None: + self.request_settings = value + + def get_ydb_request_settings(self) -> ydb.BaseRequestSettings: + return self.request_settings + + def _get_request_settings(self) -> ydb.BaseRequestSettings: + settings = self.request_settings.make_copy() + + if self.request_settings.trace_id is None: + settings = settings.with_trace_id(maybe_get_current_trace_id()) + + return settings + def _get_client_settings(self) -> ydb.QueryClientSettings: return ( ydb.QueryClientSettings() @@ -172,6 +190,7 @@ def cursor(self) -> Cursor: tx_mode=self._tx_mode, tx_context=self._tx_context, table_path_prefix=self.table_path_prefix, + request_settings=self.request_settings, ) def wait_ready(self, timeout: int = 10) -> None: @@ -197,7 +216,8 @@ def begin(self) -> None: @handle_ydb_errors def commit(self) -> None: if self._tx_context and self._tx_context.tx_id: - self._tx_context.commit() + settings = self._get_request_settings() + self._tx_context.commit(settings=settings) self._session_pool.release(self._session) self._tx_context = None self._session = None @@ -205,7 +225,8 @@ def commit(self) -> None: @handle_ydb_errors def rollback(self) -> None: if self._tx_context and self._tx_context.tx_id: - self._tx_context.rollback() + settings = self._get_request_settings() + self._tx_context.rollback(settings=settings) self._session_pool.release(self._session) self._tx_context = None self._session = None @@ -223,10 +244,15 @@ def close(self) -> None: @handle_ydb_errors def describe(self, table_path: str) -> ydb.TableSchemeEntry: + settings = self._get_request_settings() + abs_table_path = posixpath.join( self.database, self.table_path_prefix, table_path ) - return self._driver.table_client.describe_table(abs_table_path) + return self._driver.table_client.describe_table( + abs_table_path, + settings=settings, + ) @handle_ydb_errors def check_exists(self, table_path: str) -> bool: @@ -243,9 +269,12 @@ def get_table_names(self) -> list[str]: def _check_path_exists(self, table_path: str) -> bool: try: + settings = self._get_request_settings() def callee() -> None: - self._driver.scheme_client.describe_path(table_path) + self._driver.scheme_client.describe_path( + table_path, settings=settings + ) retry_operation_sync(callee) except ydb.SchemeError: @@ -254,8 +283,13 @@ def callee() -> None: return True def _get_table_names(self, abs_dir_path: str) -> list[str]: + settings = self._get_request_settings() + def callee() -> ydb.Directory: - return self._driver.scheme_client.list_directory(abs_dir_path) + return self._driver.scheme_client.list_directory( + abs_dir_path, + settings=settings, + ) directory = retry_operation_sync(callee) result = [] @@ -300,6 +334,7 @@ def cursor(self) -> AsyncCursor: tx_mode=self._tx_mode, tx_context=self._tx_context, table_path_prefix=self.table_path_prefix, + request_settings=self.request_settings, ) async def wait_ready(self, timeout: int = 10) -> None: @@ -325,7 +360,8 @@ async def begin(self) -> None: @handle_ydb_errors async def commit(self) -> None: if self._session and self._tx_context and self._tx_context.tx_id: - await self._tx_context.commit() + settings = self._get_request_settings() + await self._tx_context.commit(settings=settings) await self._session_pool.release(self._session) self._session = None self._tx_context = None @@ -333,7 +369,8 @@ async def commit(self) -> None: @handle_ydb_errors async def rollback(self) -> None: if self._session and self._tx_context and self._tx_context.tx_id: - await self._tx_context.rollback() + settings = self._get_request_settings() + await self._tx_context.rollback(settings=settings) await self._session_pool.release(self._session) self._session = None self._tx_context = None @@ -351,10 +388,15 @@ async def close(self) -> None: @handle_ydb_errors async def describe(self, table_path: str) -> ydb.TableSchemeEntry: + settings = self._get_request_settings() + abs_table_path = posixpath.join( self.database, self.table_path_prefix, table_path ) - return await self._driver.table_client.describe_table(abs_table_path) + return await self._driver.table_client.describe_table( + abs_table_path, + settings=settings, + ) @handle_ydb_errors async def check_exists(self, table_path: str) -> bool: @@ -371,9 +413,13 @@ async def get_table_names(self) -> list[str]: async def _check_path_exists(self, table_path: str) -> bool: try: + settings = self._get_request_settings() async def callee() -> None: - await self._driver.scheme_client.describe_path(table_path) + await self._driver.scheme_client.describe_path( + table_path, + settings=settings, + ) await retry_operation_async(callee) except ydb.SchemeError: @@ -382,9 +428,12 @@ async def callee() -> None: return True async def _get_table_names(self, abs_dir_path: str) -> list[str]: + settings = self._get_request_settings() + async def callee() -> ydb.Directory: return await self._driver.scheme_client.list_directory( - abs_dir_path + abs_dir_path, + settings=settings, ) directory = await retry_operation_async(callee) diff --git a/ydb_dbapi/cursors.py b/ydb_dbapi/cursors.py index 2d89c50..c10ca80 100644 --- a/ydb_dbapi/cursors.py +++ b/ydb_dbapi/cursors.py @@ -16,6 +16,7 @@ from .errors import ProgrammingError from .utils import CursorStatus from .utils import handle_ydb_errors +from .utils import maybe_get_current_trace_id ParametersType = dict[ str, @@ -148,12 +149,14 @@ def __init__( self, session_pool: ydb.QuerySessionPool, tx_mode: ydb.BaseQueryTxMode, + request_settings: ydb.BaseRequestSettings, tx_context: ydb.QueryTxContext | None = None, table_path_prefix: str = "", ) -> None: super().__init__() self._session_pool = session_pool self._tx_mode = tx_mode + self._request_settings = request_settings self._tx_context = tx_context self._table_path_prefix = table_path_prefix @@ -169,16 +172,27 @@ def fetchmany(self, size: int | None = None) -> list: def fetchall(self) -> list: return self._fetchall_from_buffer() + def _get_request_settings(self) -> ydb.BaseRequestSettings: + settings = self._request_settings.make_copy() + + if self._request_settings.trace_id is None: + settings = settings.with_trace_id(maybe_get_current_trace_id()) + + return settings + @handle_ydb_errors def _execute_generic_query( self, query: str, parameters: ParametersType | None = None ) -> Iterator[ydb.convert.ResultSet]: + settings = self._get_request_settings() + def callee( session: ydb.QuerySession, ) -> Iterator[ydb.convert.ResultSet]: return session.execute( query=query, parameters=parameters, + settings=settings, ) return self._session_pool.retry_operation_sync(callee) @@ -189,6 +203,8 @@ def _execute_session_query( query: str, parameters: ParametersType | None = None, ) -> Iterator[ydb.convert.ResultSet]: + settings = self._get_request_settings() + def callee( session: ydb.QuerySession, ) -> Iterator[ydb.convert.ResultSet]: @@ -196,6 +212,7 @@ def callee( query=query, parameters=parameters, commit_tx=True, + settings=settings, ) return self._session_pool.retry_operation_sync(callee) @@ -207,10 +224,12 @@ def _execute_transactional_query( query: str, parameters: ParametersType | None = None, ) -> Iterator[ydb.convert.ResultSet]: + settings = self._get_request_settings() return tx_context.execute( query=query, parameters=parameters, commit_tx=False, + settings=settings, ) def execute_scheme( @@ -304,12 +323,14 @@ def __init__( self, session_pool: ydb.aio.QuerySessionPool, tx_mode: ydb.BaseQueryTxMode, + request_settings: ydb.BaseRequestSettings, tx_context: ydb.aio.QueryTxContext | None = None, table_path_prefix: str = "", ) -> None: super().__init__() self._session_pool = session_pool self._tx_mode = tx_mode + self._request_settings = request_settings self._tx_context = tx_context self._table_path_prefix = table_path_prefix @@ -325,16 +346,27 @@ async def fetchmany(self, size: int | None = None) -> list: async def fetchall(self) -> list: return self._fetchall_from_buffer() + def _get_request_settings(self) -> ydb.BaseRequestSettings: + settings = self._request_settings.make_copy() + + if self._request_settings.trace_id is None: + settings = settings.with_trace_id(maybe_get_current_trace_id()) + + return settings + @handle_ydb_errors async def _execute_generic_query( self, query: str, parameters: ParametersType | None = None ) -> AsyncIterator[ydb.convert.ResultSet]: + settings = self._get_request_settings() + async def callee( session: ydb.aio.QuerySession, ) -> AsyncIterator[ydb.convert.ResultSet]: return await session.execute( query=query, parameters=parameters, + settings=settings, ) return await self._session_pool.retry_operation_async(callee) @@ -345,6 +377,8 @@ async def _execute_session_query( query: str, parameters: ParametersType | None = None, ) -> AsyncIterator[ydb.convert.ResultSet]: + settings = self._get_request_settings() + async def callee( session: ydb.aio.QuerySession, ) -> AsyncIterator[ydb.convert.ResultSet]: @@ -352,6 +386,7 @@ async def callee( query=query, parameters=parameters, commit_tx=True, + settings=settings, ) return await self._session_pool.retry_operation_async(callee) @@ -363,10 +398,12 @@ async def _execute_transactional_query( query: str, parameters: ParametersType | None = None, ) -> AsyncIterator[ydb.convert.ResultSet]: + settings = self._get_request_settings() return await tx_context.execute( query=query, parameters=parameters, commit_tx=False, + settings=settings, ) async def execute_scheme( diff --git a/ydb_dbapi/utils.py b/ydb_dbapi/utils.py index b69b58a..38f964f 100644 --- a/ydb_dbapi/utils.py +++ b/ydb_dbapi/utils.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import functools +import importlib.util from enum import Enum from inspect import iscoroutinefunction from typing import Any @@ -100,3 +103,17 @@ class CursorStatus(str, Enum): running = "running" finished = "finished" closed = "closed" + + +def maybe_get_current_trace_id() -> str | None: + # Check if OpenTelemetry is available + if importlib.util.find_spec("opentelemetry"): + from opentelemetry import trace # type: ignore + + current_span = trace.get_current_span() + + if current_span.get_span_context().is_valid: + return format(current_span.get_span_context().trace_id, "032x") + + # Return None if OpenTelemetry is not available or trace ID is invalid + return None