Skip to content

Cherry pick settings propagation #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ class TestCursor(BaseCursorTestSuit):
def sync_cursor(
self, session_pool_sync: ydb.QuerySessionPool
) -> Generator[Cursor]:
cursor = Cursor(session_pool_sync, ydb.QuerySerializableReadWrite())
cursor = Cursor(
session_pool_sync,
ydb.QuerySerializableReadWrite(),
request_settings=ydb.BaseRequestSettings(),
)
yield cursor
cursor.close()

Expand Down Expand Up @@ -177,7 +181,11 @@ class TestAsyncCursor(BaseCursorTestSuit):
async def async_cursor(
self, session_pool: ydb.aio.QuerySessionPool
) -> AsyncGenerator[Cursor]:
cursor = AsyncCursor(session_pool, ydb.QuerySerializableReadWrite())
cursor = AsyncCursor(
session_pool,
ydb.QuerySerializableReadWrite(),
request_settings=ydb.BaseRequestSettings(),
)
yield cursor
await greenlet_spawn(cursor.close)

Expand Down
69 changes: 59 additions & 10 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .errors import InternalError
from .errors import NotSupportedError
from .utils import handle_ydb_errors
from .utils import maybe_get_current_trace_id


class IsolationLevel(str, Enum):
Expand Down Expand Up @@ -101,6 +102,9 @@ def __init__(
self._session_pool = self._pool_cls(self._driver, size=5)

self._session: ydb.QuerySession | ydb.aio.QuerySession | None = None
self.request_settings: ydb.BaseRequestSettings = (
ydb.BaseRequestSettings()
)

def set_isolation_level(self, isolation_level: IsolationLevel) -> None:
if self._tx_context and self._tx_context.tx_id:
Expand Down Expand Up @@ -129,6 +133,20 @@ def get_isolation_level(self) -> str:
msg = f"{self._tx_mode.name} is not supported"
raise NotSupportedError(msg)

def set_ydb_request_settings(self, value: ydb.BaseRequestSettings) -> None:
self.request_settings = value

def get_ydb_request_settings(self) -> ydb.BaseRequestSettings:
return self.request_settings

def _get_request_settings(self) -> ydb.BaseRequestSettings:
settings = self.request_settings.make_copy()

if self.request_settings.trace_id is None:
settings = settings.with_trace_id(maybe_get_current_trace_id())

return settings

def _get_client_settings(self) -> ydb.QueryClientSettings:
return (
ydb.QueryClientSettings()
Expand Down Expand Up @@ -172,6 +190,7 @@ def cursor(self) -> Cursor:
tx_mode=self._tx_mode,
tx_context=self._tx_context,
table_path_prefix=self.table_path_prefix,
request_settings=self.request_settings,
)

def wait_ready(self, timeout: int = 10) -> None:
Expand All @@ -197,15 +216,17 @@ def begin(self) -> None:
@handle_ydb_errors
def commit(self) -> None:
if self._tx_context and self._tx_context.tx_id:
self._tx_context.commit()
settings = self._get_request_settings()
self._tx_context.commit(settings=settings)
self._session_pool.release(self._session)
self._tx_context = None
self._session = None

@handle_ydb_errors
def rollback(self) -> None:
if self._tx_context and self._tx_context.tx_id:
self._tx_context.rollback()
settings = self._get_request_settings()
self._tx_context.rollback(settings=settings)
self._session_pool.release(self._session)
self._tx_context = None
self._session = None
Expand All @@ -223,10 +244,15 @@ def close(self) -> None:

@handle_ydb_errors
def describe(self, table_path: str) -> ydb.TableSchemeEntry:
settings = self._get_request_settings()

abs_table_path = posixpath.join(
self.database, self.table_path_prefix, table_path
)
return self._driver.table_client.describe_table(abs_table_path)
return self._driver.table_client.describe_table(
abs_table_path,
settings=settings,
)

@handle_ydb_errors
def check_exists(self, table_path: str) -> bool:
Expand All @@ -243,9 +269,12 @@ def get_table_names(self) -> list[str]:

def _check_path_exists(self, table_path: str) -> bool:
try:
settings = self._get_request_settings()

def callee() -> None:
self._driver.scheme_client.describe_path(table_path)
self._driver.scheme_client.describe_path(
table_path, settings=settings
)

retry_operation_sync(callee)
except ydb.SchemeError:
Expand All @@ -254,8 +283,13 @@ def callee() -> None:
return True

def _get_table_names(self, abs_dir_path: str) -> list[str]:
settings = self._get_request_settings()

def callee() -> ydb.Directory:
return self._driver.scheme_client.list_directory(abs_dir_path)
return self._driver.scheme_client.list_directory(
abs_dir_path,
settings=settings,
)

directory = retry_operation_sync(callee)
result = []
Expand Down Expand Up @@ -300,6 +334,7 @@ def cursor(self) -> AsyncCursor:
tx_mode=self._tx_mode,
tx_context=self._tx_context,
table_path_prefix=self.table_path_prefix,
request_settings=self.request_settings,
)

async def wait_ready(self, timeout: int = 10) -> None:
Expand All @@ -325,15 +360,17 @@ async def begin(self) -> None:
@handle_ydb_errors
async def commit(self) -> None:
if self._session and self._tx_context and self._tx_context.tx_id:
await self._tx_context.commit()
settings = self._get_request_settings()
await self._tx_context.commit(settings=settings)
await self._session_pool.release(self._session)
self._session = None
self._tx_context = None

@handle_ydb_errors
async def rollback(self) -> None:
if self._session and self._tx_context and self._tx_context.tx_id:
await self._tx_context.rollback()
settings = self._get_request_settings()
await self._tx_context.rollback(settings=settings)
await self._session_pool.release(self._session)
self._session = None
self._tx_context = None
Expand All @@ -351,10 +388,15 @@ async def close(self) -> None:

@handle_ydb_errors
async def describe(self, table_path: str) -> ydb.TableSchemeEntry:
settings = self._get_request_settings()

abs_table_path = posixpath.join(
self.database, self.table_path_prefix, table_path
)
return await self._driver.table_client.describe_table(abs_table_path)
return await self._driver.table_client.describe_table(
abs_table_path,
settings=settings,
)

@handle_ydb_errors
async def check_exists(self, table_path: str) -> bool:
Expand All @@ -371,9 +413,13 @@ async def get_table_names(self) -> list[str]:

async def _check_path_exists(self, table_path: str) -> bool:
try:
settings = self._get_request_settings()

async def callee() -> None:
await self._driver.scheme_client.describe_path(table_path)
await self._driver.scheme_client.describe_path(
table_path,
settings=settings,
)

await retry_operation_async(callee)
except ydb.SchemeError:
Expand All @@ -382,9 +428,12 @@ async def callee() -> None:
return True

async def _get_table_names(self, abs_dir_path: str) -> list[str]:
settings = self._get_request_settings()

async def callee() -> ydb.Directory:
return await self._driver.scheme_client.list_directory(
abs_dir_path
abs_dir_path,
settings=settings,
)

directory = await retry_operation_async(callee)
Expand Down
37 changes: 37 additions & 0 deletions ydb_dbapi/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .errors import ProgrammingError
from .utils import CursorStatus
from .utils import handle_ydb_errors
from .utils import maybe_get_current_trace_id

ParametersType = dict[
str,
Expand Down Expand Up @@ -148,12 +149,14 @@ def __init__(
self,
session_pool: ydb.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
request_settings: ydb.BaseRequestSettings,
tx_context: ydb.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._session_pool = session_pool
self._tx_mode = tx_mode
self._request_settings = request_settings
self._tx_context = tx_context
self._table_path_prefix = table_path_prefix

Expand All @@ -169,16 +172,27 @@ def fetchmany(self, size: int | None = None) -> list:
def fetchall(self) -> list:
return self._fetchall_from_buffer()

def _get_request_settings(self) -> ydb.BaseRequestSettings:
settings = self._request_settings.make_copy()

if self._request_settings.trace_id is None:
settings = settings.with_trace_id(maybe_get_current_trace_id())

return settings

@handle_ydb_errors
def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> Iterator[ydb.convert.ResultSet]:
settings = self._get_request_settings()

def callee(
session: ydb.QuerySession,
) -> Iterator[ydb.convert.ResultSet]:
return session.execute(
query=query,
parameters=parameters,
settings=settings,
)

return self._session_pool.retry_operation_sync(callee)
Expand All @@ -189,13 +203,16 @@ def _execute_session_query(
query: str,
parameters: ParametersType | None = None,
) -> Iterator[ydb.convert.ResultSet]:
settings = self._get_request_settings()

def callee(
session: ydb.QuerySession,
) -> Iterator[ydb.convert.ResultSet]:
return session.transaction(self._tx_mode).execute(
query=query,
parameters=parameters,
commit_tx=True,
settings=settings,
)

return self._session_pool.retry_operation_sync(callee)
Expand All @@ -207,10 +224,12 @@ def _execute_transactional_query(
query: str,
parameters: ParametersType | None = None,
) -> Iterator[ydb.convert.ResultSet]:
settings = self._get_request_settings()
return tx_context.execute(
query=query,
parameters=parameters,
commit_tx=False,
settings=settings,
)

def execute_scheme(
Expand Down Expand Up @@ -304,12 +323,14 @@ def __init__(
self,
session_pool: ydb.aio.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
request_settings: ydb.BaseRequestSettings,
tx_context: ydb.aio.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._session_pool = session_pool
self._tx_mode = tx_mode
self._request_settings = request_settings
self._tx_context = tx_context
self._table_path_prefix = table_path_prefix

Expand All @@ -325,16 +346,27 @@ async def fetchmany(self, size: int | None = None) -> list:
async def fetchall(self) -> list:
return self._fetchall_from_buffer()

def _get_request_settings(self) -> ydb.BaseRequestSettings:
settings = self._request_settings.make_copy()

if self._request_settings.trace_id is None:
settings = settings.with_trace_id(maybe_get_current_trace_id())

return settings

@handle_ydb_errors
async def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> AsyncIterator[ydb.convert.ResultSet]:
settings = self._get_request_settings()

async def callee(
session: ydb.aio.QuerySession,
) -> AsyncIterator[ydb.convert.ResultSet]:
return await session.execute(
query=query,
parameters=parameters,
settings=settings,
)

return await self._session_pool.retry_operation_async(callee)
Expand All @@ -345,13 +377,16 @@ async def _execute_session_query(
query: str,
parameters: ParametersType | None = None,
) -> AsyncIterator[ydb.convert.ResultSet]:
settings = self._get_request_settings()

async def callee(
session: ydb.aio.QuerySession,
) -> AsyncIterator[ydb.convert.ResultSet]:
return await session.transaction(self._tx_mode).execute(
query=query,
parameters=parameters,
commit_tx=True,
settings=settings,
)

return await self._session_pool.retry_operation_async(callee)
Expand All @@ -363,10 +398,12 @@ async def _execute_transactional_query(
query: str,
parameters: ParametersType | None = None,
) -> AsyncIterator[ydb.convert.ResultSet]:
settings = self._get_request_settings()
return await tx_context.execute(
query=query,
parameters=parameters,
commit_tx=False,
settings=settings,
)

async def execute_scheme(
Expand Down
17 changes: 17 additions & 0 deletions ydb_dbapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import functools
import importlib.util
from enum import Enum
from inspect import iscoroutinefunction
from typing import Any
Expand Down Expand Up @@ -100,3 +103,17 @@ class CursorStatus(str, Enum):
running = "running"
finished = "finished"
closed = "closed"


def maybe_get_current_trace_id() -> str | None:
# Check if OpenTelemetry is available
if importlib.util.find_spec("opentelemetry"):
from opentelemetry import trace # type: ignore

current_span = trace.get_current_span()

if current_span.get_span_context().is_valid:
return format(current_span.get_span_context().trace_id, "032x")

# Return None if OpenTelemetry is not available or trace ID is invalid
return None
Loading