Skip to content

Commit af51b1f

Browse files
committed
review fixes
1 parent c98d958 commit af51b1f

File tree

3 files changed

+134
-98
lines changed

3 files changed

+134
-98
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ ignore = [
6161
"TRY003", # Allow specifying long messages outside the exception class
6262
"SLF001", # Allow access private member,
6363
"PGH003", # Allow not to specify rule codes
64+
"PLR0913", # Allow to have many arguments in function definition
6465
]
6566
select = ["ALL"]
6667

ydb_dbapi/connections.py

Lines changed: 113 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
import posixpath
4+
from enum import Enum
45
from typing import NamedTuple
5-
from typing import TypedDict
66

77
import ydb
8-
from typing_extensions import NotRequired
9-
from typing_extensions import Unpack
8+
from ydb import QuerySessionPool as SessionPool
9+
from ydb import QueryTxContext as TxContext
10+
from ydb.aio import QuerySessionPool as AsyncSessionPool
11+
from ydb.aio import QueryTxContext as AsyncTxContext
1012
from ydb.retries import retry_operation_async
1113
from ydb.retries import retry_operation_sync
1214

@@ -18,7 +20,7 @@
1820
from .utils import handle_ydb_errors
1921

2022

21-
class IsolationLevel:
23+
class IsolationLevel(str, Enum):
2224
SERIALIZABLE = "SERIALIZABLE"
2325
ONLINE_READONLY = "ONLINE READONLY"
2426
ONLINE_READONLY_INCONSISTENT = "ONLINE READONLY INCONSISTENT"
@@ -27,49 +29,63 @@ class IsolationLevel:
2729
AUTOCOMMIT = "AUTOCOMMIT"
2830

2931

30-
class ConnectionKwargs(TypedDict):
31-
credentials: NotRequired[ydb.AbstractCredentials]
32-
ydb_table_path_prefix: NotRequired[str]
33-
ydb_session_pool: NotRequired[
34-
ydb.QuerySessionPool | ydb.aio.QuerySessionPool
35-
]
32+
class _IsolationSettings(NamedTuple):
33+
ydb_mode: ydb.BaseQueryTxMode
34+
interactive: bool
35+
36+
37+
_ydb_isolation_settings_map = {
38+
IsolationLevel.AUTOCOMMIT: _IsolationSettings(
39+
ydb.QuerySerializableReadWrite(), interactive=False
40+
),
41+
IsolationLevel.SERIALIZABLE: _IsolationSettings(
42+
ydb.QuerySerializableReadWrite(), interactive=True
43+
),
44+
IsolationLevel.ONLINE_READONLY: _IsolationSettings(
45+
ydb.QueryOnlineReadOnly(), interactive=True
46+
),
47+
IsolationLevel.ONLINE_READONLY_INCONSISTENT: _IsolationSettings(
48+
ydb.QueryOnlineReadOnly().with_allow_inconsistent_reads(),
49+
interactive=True,
50+
),
51+
IsolationLevel.STALE_READONLY: _IsolationSettings(
52+
ydb.QueryStaleReadOnly(), interactive=True
53+
),
54+
IsolationLevel.SNAPSHOT_READONLY: _IsolationSettings(
55+
ydb.QuerySnapshotReadOnly(), interactive=True
56+
),
57+
}
3658

3759

3860
class BaseConnection:
39-
_tx_mode: ydb.BaseQueryTxMode = ydb.QuerySerializableReadWrite()
40-
_tx_context: ydb.QueryTxContext | ydb.aio.QueryTxContext | None = None
41-
interactive_transaction: bool = False
42-
_shared_session_pool: bool = False
43-
4461
_driver_cls = ydb.Driver
4562
_pool_cls = ydb.QuerySessionPool
46-
_cursor_cls: type[Cursor | AsyncCursor] = Cursor
47-
48-
_driver: ydb.Driver | ydb.aio.Driver
49-
_pool: ydb.QuerySessionPool | ydb.aio.QuerySessionPool
50-
51-
_current_cursor: AsyncCursor | Cursor | None = None
5263

5364
def __init__(
5465
self,
5566
host: str = "",
5667
port: str = "",
5768
database: str = "",
58-
**conn_kwargs: Unpack[ConnectionKwargs],
69+
ydb_table_path_prefix: str = "",
70+
credentials: ydb.AbstractCredentials | None = None,
71+
ydb_session_pool: SessionPool | AsyncSessionPool | None = None,
72+
**kwargs: dict,
5973
) -> None:
6074
self.endpoint = f"grpc://{host}:{port}"
6175
self.database = database
62-
self.conn_kwargs = conn_kwargs
63-
self.credentials = self.conn_kwargs.pop("credentials", None)
64-
self.table_path_prefix = self.conn_kwargs.pop(
65-
"ydb_table_path_prefix", ""
66-
)
76+
self.credentials = credentials
77+
self.table_path_prefix = ydb_table_path_prefix
6778

68-
if (
69-
"ydb_session_pool" in self.conn_kwargs
70-
): # Use session pool managed manually
79+
self.connection_kwargs: dict = kwargs
80+
81+
self._tx_mode: ydb.BaseQueryTxMode = ydb.QuerySerializableReadWrite()
82+
self._tx_context: TxContext | AsyncTxContext | None = None
83+
self.interactive_transaction: bool = False
84+
self._shared_session_pool: bool = False
85+
86+
if ydb_session_pool is not None:
7187
self._shared_session_pool = True
72-
self._session_pool = self.conn_kwargs.pop("ydb_session_pool")
88+
self._session_pool = ydb_session_pool
7389
self._driver = self._session_pool._driver
7490
else:
7591
driver_config = ydb.DriverConfig(
@@ -82,33 +98,8 @@ def __init__(
8298

8399
self._session: ydb.QuerySession | ydb.aio.QuerySession | None = None
84100

85-
def set_isolation_level(self, isolation_level: str) -> None:
86-
class IsolationSettings(NamedTuple):
87-
ydb_mode: ydb.BaseQueryTxMode
88-
interactive: bool
89-
90-
ydb_isolation_settings_map = {
91-
IsolationLevel.AUTOCOMMIT: IsolationSettings(
92-
ydb.QuerySerializableReadWrite(), interactive=False
93-
),
94-
IsolationLevel.SERIALIZABLE: IsolationSettings(
95-
ydb.QuerySerializableReadWrite(), interactive=True
96-
),
97-
IsolationLevel.ONLINE_READONLY: IsolationSettings(
98-
ydb.QueryOnlineReadOnly(), interactive=True
99-
),
100-
IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings(
101-
ydb.QueryOnlineReadOnly().with_allow_inconsistent_reads(),
102-
interactive=True,
103-
),
104-
IsolationLevel.STALE_READONLY: IsolationSettings(
105-
ydb.QueryStaleReadOnly(), interactive=True
106-
),
107-
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(
108-
ydb.QuerySnapshotReadOnly(), interactive=True
109-
),
110-
}
111-
ydb_isolation_settings = ydb_isolation_settings_map[isolation_level]
101+
def set_isolation_level(self, isolation_level: IsolationLevel) -> None:
102+
ydb_isolation_settings = _ydb_isolation_settings_map[isolation_level]
112103
if self._tx_context and self._tx_context.tx_id:
113104
raise InternalError(
114105
"Failed to set transaction mode: transaction is already began"
@@ -132,7 +123,34 @@ def get_isolation_level(self) -> str:
132123
msg = f"{self._tx_mode.name} is not supported"
133124
raise NotSupportedError(msg)
134125

135-
def cursor(self) -> Cursor | AsyncCursor:
126+
127+
class Connection(BaseConnection):
128+
_driver_cls = ydb.Driver
129+
_pool_cls = ydb.QuerySessionPool
130+
_cursor_cls = Cursor
131+
132+
def __init__(
133+
self,
134+
host: str = "",
135+
port: str = "",
136+
database: str = "",
137+
ydb_table_path_prefix: str = "",
138+
credentials: ydb.AbstractCredentials | None = None,
139+
ydb_session_pool: SessionPool | AsyncSessionPool | None = None,
140+
**kwargs: dict,
141+
) -> None:
142+
super().__init__(
143+
host=host,
144+
port=port,
145+
database=database,
146+
ydb_table_path_prefix=ydb_table_path_prefix,
147+
credentials=credentials,
148+
ydb_session_pool=ydb_session_pool,
149+
**kwargs,
150+
)
151+
self._current_cursor: Cursor | None = None
152+
153+
def cursor(self) -> Cursor:
136154
if self._session is None:
137155
raise RuntimeError("Connection is not ready, use wait_ready.")
138156

@@ -148,16 +166,6 @@ def cursor(self) -> Cursor | AsyncCursor:
148166
)
149167
return self._current_cursor
150168

151-
152-
class Connection(BaseConnection):
153-
_driver_cls = ydb.Driver
154-
_pool_cls = ydb.QuerySessionPool
155-
_cursor_cls = Cursor
156-
157-
_driver: ydb.Driver
158-
_pool: ydb.QuerySessionPool
159-
_current_cursor: Cursor | None = None
160-
161169
def wait_ready(self, timeout: int = 10) -> None:
162170
try:
163171
self._driver.wait(timeout, fail_fast=True)
@@ -248,9 +256,42 @@ class AsyncConnection(BaseConnection):
248256
_pool_cls = ydb.aio.QuerySessionPool
249257
_cursor_cls = AsyncCursor
250258

251-
_driver: ydb.aio.Driver
252-
_pool: ydb.aio.QuerySessionPool
253-
_current_cursor: AsyncCursor | None = None
259+
def __init__(
260+
self,
261+
host: str = "",
262+
port: str = "",
263+
database: str = "",
264+
ydb_table_path_prefix: str = "",
265+
credentials: ydb.AbstractCredentials | None = None,
266+
ydb_session_pool: SessionPool | AsyncSessionPool | None = None,
267+
**kwargs: dict,
268+
) -> None:
269+
super().__init__(
270+
host=host,
271+
port=port,
272+
database=database,
273+
ydb_table_path_prefix=ydb_table_path_prefix,
274+
credentials=credentials,
275+
ydb_session_pool=ydb_session_pool,
276+
**kwargs,
277+
)
278+
self._current_cursor: AsyncCursor | None = None
279+
280+
def cursor(self) -> AsyncCursor:
281+
if self._session is None:
282+
raise RuntimeError("Connection is not ready, use wait_ready.")
283+
284+
if self.interactive_transaction:
285+
self._tx_context = self._session.transaction(self._tx_mode)
286+
else:
287+
self._tx_context = None
288+
289+
self._current_cursor = self._cursor_cls(
290+
session=self._session,
291+
tx_context=self._tx_context,
292+
autocommit=(not self.interactive_transaction),
293+
)
294+
return self._current_cursor
254295

255296
async def wait_ready(self, timeout: int = 10) -> None:
256297
try:

ydb_dbapi/cursors.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from typing_extensions import Self
1212

1313
from .errors import DatabaseError
14-
from .errors import Error
1514
from .errors import InterfaceError
1615
from .errors import ProgrammingError
1716
from .utils import CursorStatus
@@ -32,11 +31,12 @@ def _get_column_type(type_obj: Any) -> str:
3231

3332

3433
class BufferedCursor:
35-
arraysize: int = 1
36-
_rows: Iterator | None = None
37-
_rows_count: int = -1
38-
_description: list[tuple] | None = None
39-
_state: CursorStatus = CursorStatus.ready
34+
def __init__(self) -> None:
35+
self.arraysize: int = 1
36+
self._rows: Iterator | None = None
37+
self._rows_count: int = -1
38+
self._description: list[tuple] | None = None
39+
self._state: CursorStatus = CursorStatus.ready
4040

4141
@property
4242
def description(self) -> list[tuple] | None:
@@ -142,6 +142,7 @@ def __init__(
142142
table_path_prefix: str = "",
143143
autocommit: bool = True,
144144
) -> None:
145+
super().__init__()
145146
self._session = session
146147
self._tx_context = tx_context
147148
self._table_path_prefix = table_path_prefix
@@ -167,13 +168,12 @@ def _execute_generic_query(
167168

168169
@handle_ydb_errors
169170
def _execute_transactional_query(
170-
self, query: str, parameters: ParametersType | None = None
171+
self,
172+
tx_context: ydb.QueryTxContext,
173+
query: str,
174+
parameters: ParametersType | None = None,
171175
) -> Iterator[ydb.convert.ResultSet]:
172-
if self._tx_context is None:
173-
raise Error(
174-
"Unable to execute tx based queries without transaction."
175-
)
176-
return self._tx_context.execute(
176+
return tx_context.execute(
177177
query=query,
178178
parameters=parameters,
179179
commit_tx=self._autocommit,
@@ -188,16 +188,13 @@ def execute(
188188
self._raise_if_running()
189189
if self._tx_context is not None:
190190
self._stream = self._execute_transactional_query(
191-
query=query, parameters=parameters
191+
tx_context=self._tx_context, query=query, parameters=parameters
192192
)
193193
else:
194194
self._stream = self._execute_generic_query(
195195
query=query, parameters=parameters
196196
)
197197

198-
if self._stream is None:
199-
return
200-
201198
self._begin_query()
202199

203200
self._scroll_stream(replace_current=False)
@@ -256,6 +253,7 @@ def __init__(
256253
table_path_prefix: str = "",
257254
autocommit: bool = True,
258255
) -> None:
256+
super().__init__()
259257
self._session = session
260258
self._tx_context = tx_context
261259
self._table_path_prefix = table_path_prefix
@@ -281,13 +279,12 @@ async def _execute_generic_query(
281279

282280
@handle_ydb_errors
283281
async def _execute_transactional_query(
284-
self, query: str, parameters: ParametersType | None = None
282+
self,
283+
tx_context: ydb.aio.QueryTxContext,
284+
query: str,
285+
parameters: ParametersType | None = None,
285286
) -> AsyncIterator[ydb.convert.ResultSet]:
286-
if self._tx_context is None:
287-
raise Error(
288-
"Unable to execute tx based queries without transaction."
289-
)
290-
return await self._tx_context.execute(
287+
return await tx_context.execute(
291288
query=query,
292289
parameters=parameters,
293290
commit_tx=self._autocommit,
@@ -302,16 +299,13 @@ async def execute(
302299
self._raise_if_running()
303300
if self._tx_context is not None:
304301
self._stream = await self._execute_transactional_query(
305-
query=query, parameters=parameters
302+
tx_context=self._tx_context, query=query, parameters=parameters
306303
)
307304
else:
308305
self._stream = await self._execute_generic_query(
309306
query=query, parameters=parameters
310307
)
311308

312-
if self._stream is None:
313-
return
314-
315309
self._begin_query()
316310

317311
await self._scroll_stream(replace_current=False)

0 commit comments

Comments
 (0)