Skip to content

Commit e4958e9

Browse files
committed
continue to move common parts from connections
1 parent ef592dc commit e4958e9

File tree

1 file changed

+61
-75
lines changed

1 file changed

+61
-75
lines changed

ydb_dbapi/connections.py

Lines changed: 61 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

33
import posixpath
4-
from typing import Any
54
from typing import NamedTuple
5+
from typing import TypedDict
66

77
import ydb
8+
from typing_extensions import NotRequired
9+
from typing_extensions import Unpack
810
from ydb.retries import retry_operation_async
911
from ydb.retries import retry_operation_sync
1012

@@ -25,11 +27,59 @@ class IsolationLevel:
2527
AUTOCOMMIT = "AUTOCOMMIT"
2628

2729

30+
class ConnectionKwargs(TypedDict):
31+
credentials: NotRequired[ydb.AbstractCredentials]
32+
ydb_table_path_prefix: NotRequired[str]
33+
ydb_session_pool: NotRequired[
34+
ydb.QuerySessionPool | ydb.aio.QuerySessionPool
35+
]
36+
37+
2838
class BaseConnection:
2939
_tx_mode: ydb.BaseQueryTxMode = ydb.QuerySerializableReadWrite()
3040
_tx_context: ydb.QueryTxContext | ydb.aio.QueryTxContext | None = None
3141
interactive_transaction: bool = False
3242

43+
_shared_session_pool: bool = False
44+
_driver_cls = ydb.Driver
45+
_driver: ydb.Driver | ydb.aio.Driver
46+
_pool_cls = ydb.QuerySessionPool
47+
_pool: ydb.QuerySessionPool | ydb.aio.QuerySessionPool
48+
49+
_current_cursor: AsyncCursor | Cursor | None = None
50+
51+
def __init__(
52+
self,
53+
host: str = "",
54+
port: str = "",
55+
database: str = "",
56+
**conn_kwargs: Unpack[ConnectionKwargs],
57+
) -> None:
58+
self.endpoint = f"grpc://{host}:{port}"
59+
self.database = database
60+
self.conn_kwargs = conn_kwargs
61+
self.credentials = self.conn_kwargs.pop("credentials", None)
62+
self.table_path_prefix = self.conn_kwargs.pop(
63+
"ydb_table_path_prefix", ""
64+
)
65+
66+
if (
67+
"ydb_session_pool" in self.conn_kwargs
68+
): # Use session pool managed manually
69+
self._shared_session_pool = True
70+
self._session_pool = self.conn_kwargs.pop("ydb_session_pool")
71+
self._driver = self._session_pool._driver
72+
else:
73+
driver_config = ydb.DriverConfig(
74+
endpoint=self.endpoint,
75+
database=self.database,
76+
credentials=self.credentials,
77+
)
78+
self._driver = self._driver_cls(driver_config)
79+
self._session_pool = self._pool_cls(self._driver, size=5)
80+
81+
self._session: ydb.QuerySession | ydb.aio.QuerySession | None = None
82+
3383
def set_isolation_level(self, isolation_level: str) -> None:
3484
class IsolationSettings(NamedTuple):
3585
ydb_mode: ydb.BaseQueryTxMode
@@ -82,44 +132,12 @@ def get_isolation_level(self) -> str:
82132

83133

84134
class Connection(BaseConnection):
85-
def __init__(
86-
self,
87-
host: str = "",
88-
port: str = "",
89-
database: str = "",
90-
**conn_kwargs: Any,
91-
) -> None:
92-
self.endpoint = f"grpc://{host}:{port}"
93-
self.database = database
94-
self.conn_kwargs = conn_kwargs
95-
self.credentials = self.conn_kwargs.pop("credentials", None)
96-
self.table_path_prefix = self.conn_kwargs.pop(
97-
"ydb_table_path_prefix", ""
98-
)
135+
_driver_cls = ydb.Driver
136+
_pool_cls = ydb.QuerySessionPool
99137

100-
if (
101-
"ydb_session_pool" in self.conn_kwargs
102-
): # Use session pool managed manually
103-
self._shared_session_pool = True
104-
self._session_pool = self.conn_kwargs.pop("ydb_session_pool")
105-
self._driver: ydb.Driver = self._session_pool._driver
106-
else:
107-
self._shared_session_pool = False
108-
driver_config = ydb.DriverConfig(
109-
endpoint=self.endpoint,
110-
database=self.database,
111-
credentials=self.credentials,
112-
)
113-
self._driver = ydb.Driver(driver_config)
114-
self._session_pool = ydb.QuerySessionPool(self._driver, size=5)
115-
116-
self._tx_mode: ydb.BaseQueryTxMode = ydb.QuerySerializableReadWrite()
117-
118-
self._current_cursor: Cursor | None = None
119-
self.interactive_transaction: bool = False
120-
121-
self._session: ydb.QuerySession | None = None
122-
self._tx_context: ydb.QueryTxContext | None = None
138+
_driver: ydb.Driver
139+
_pool: ydb.QuerySessionPool
140+
_current_cursor: Cursor | None = None
123141

124142
def wait_ready(self, timeout: int = 10) -> None:
125143
try:
@@ -227,44 +245,12 @@ def callee() -> ydb.Directory:
227245

228246

229247
class AsyncConnection(BaseConnection):
230-
def __init__(
231-
self,
232-
host: str = "",
233-
port: str = "",
234-
database: str = "",
235-
**conn_kwargs: Any,
236-
) -> None:
237-
self.endpoint = f"grpc://{host}:{port}"
238-
self.database = database
239-
self.conn_kwargs = conn_kwargs
240-
self.credentials = self.conn_kwargs.pop("credentials", None)
241-
self.table_path_prefix = self.conn_kwargs.pop(
242-
"ydb_table_path_prefix", ""
243-
)
244-
245-
if (
246-
"ydb_session_pool" in self.conn_kwargs
247-
): # Use session pool managed manually
248-
self._shared_session_pool = True
249-
self._session_pool = self.conn_kwargs.pop("ydb_session_pool")
250-
self._driver: ydb.aio.Driver = self._session_pool._driver
251-
else:
252-
self._shared_session_pool = False
253-
driver_config = ydb.DriverConfig(
254-
endpoint=self.endpoint,
255-
database=self.database,
256-
credentials=self.credentials,
257-
)
258-
self._driver = ydb.aio.Driver(driver_config)
259-
self._session_pool = ydb.aio.QuerySessionPool(self._driver, size=5)
260-
261-
self._tx_mode: ydb.BaseQueryTxMode = ydb.QuerySerializableReadWrite()
262-
263-
self._current_cursor: AsyncCursor | None = None
264-
self.interactive_transaction: bool = False
248+
_driver_cls = ydb.aio.Driver
249+
_pool_cls = ydb.aio.QuerySessionPool
265250

266-
self._session: ydb.aio.QuerySession | None = None
267-
self._tx_context: ydb.aio.QueryTxContext | None = None
251+
_driver: ydb.aio.Driver
252+
_pool: ydb.aio.QuerySessionPool
253+
_current_cursor: AsyncCursor | None = None
268254

269255
async def wait_ready(self, timeout: int = 10) -> None:
270256
try:

0 commit comments

Comments
 (0)