Skip to content

Commit 3a67094

Browse files
committed
Invalidate session&tx on YDB errors
1 parent eff6be5 commit 3a67094

File tree

4 files changed

+134
-0
lines changed

4 files changed

+134
-0
lines changed

tests/test_connections.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,34 @@ def _test_bulk_upsert(self, connection: dbapi.Connection) -> None:
200200

201201
maybe_await(cursor.execute_scheme("DROP TABLE pet"))
202202

203+
def _test_error_with_interactive_tx(
204+
self,
205+
connection: dbapi.Connection,
206+
) -> None:
207+
208+
cur = connection.cursor()
209+
cur.execute_scheme(
210+
"""
211+
DROP TABLE IF EXISTS test;
212+
CREATE TABLE test (
213+
id Int64 NOT NULL,
214+
val Int64,
215+
PRIMARY KEY(id)
216+
)
217+
"""
218+
)
219+
220+
connection.set_isolation_level(dbapi.IsolationLevel.SERIALIZABLE)
221+
maybe_await(connection.begin())
222+
223+
cur = connection.cursor()
224+
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))
225+
with pytest.raises(dbapi.Error):
226+
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))
227+
228+
maybe_await(cur.close())
229+
maybe_await(connection.rollback())
230+
203231

204232
class TestConnection(BaseDBApiTestSuit):
205233
@pytest.fixture
@@ -245,6 +273,11 @@ def test_errors(self, connection: dbapi.Connection) -> None:
245273
def test_bulk_upsert(self, connection: dbapi.Connection) -> None:
246274
self._test_bulk_upsert(connection)
247275

276+
def test_errors_with_interactive_tx(
277+
self, connection: dbapi.Connection
278+
) -> None:
279+
self._test_error_with_interactive_tx(connection)
280+
248281

249282
class TestAsyncConnection(BaseDBApiTestSuit):
250283
@pytest_asyncio.fixture
@@ -304,3 +337,9 @@ async def test_bulk_upsert(
304337
self, connection: dbapi.AsyncConnection
305338
) -> None:
306339
await greenlet_spawn(self._test_bulk_upsert, connection)
340+
341+
@pytest.mark.asyncio
342+
async def test_errors_with_interactive_tx(
343+
self, connection: dbapi.AsyncConnection
344+
) -> None:
345+
await greenlet_spawn(self._test_error_with_interactive_tx, connection)

tests/test_cursors.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlalchemy.util import greenlet_spawn
1111
from ydb_dbapi import AsyncCursor
1212
from ydb_dbapi import Cursor
13+
from ydb_dbapi.utils import CursorStatus
1314

1415

1516
def maybe_await(obj: callable) -> any:
@@ -22,6 +23,14 @@ def maybe_await(obj: callable) -> any:
2223
RESULT_SET_COUNT = 3
2324

2425

26+
class FakeSyncConnection:
27+
def _invalidate_session(self) -> None: ...
28+
29+
30+
class FakeAsyncConnection:
31+
async def _invalidate_session(self) -> None: ...
32+
33+
2534
class BaseCursorTestSuit:
2635
def _test_cursor_fetch_one(self, cursor: Cursor | AsyncCursor) -> None:
2736
yql_text = """
@@ -136,13 +145,24 @@ def _test_cursor_fetch_all_multiple_result_sets(
136145
assert maybe_await(cursor.fetchall()) == []
137146
assert not maybe_await(cursor.nextset())
138147

148+
def _test_cursor_state_after_error(
149+
self, cursor: Cursor | AsyncCursor
150+
) -> None:
151+
query = "INSERT INTO table (id, val) VALUES (0,0)"
152+
with pytest.raises(ydb.Error):
153+
maybe_await(cursor.execute(query=query))
154+
155+
assert cursor._state == CursorStatus.finished
156+
139157

140158
class TestCursor(BaseCursorTestSuit):
141159
@pytest.fixture
142160
def sync_cursor(
143161
self, session_pool_sync: ydb.QuerySessionPool
144162
) -> Generator[Cursor]:
163+
145164
cursor = Cursor(
165+
FakeSyncConnection(),
146166
session_pool_sync,
147167
ydb.QuerySerializableReadWrite(),
148168
request_settings=ydb.BaseRequestSettings(),
@@ -174,6 +194,10 @@ def test_cursor_fetch_all_multiple_result_sets(
174194
) -> None:
175195
self._test_cursor_fetch_all_multiple_result_sets(sync_cursor)
176196

197+
def test_cursor_state_after_error(
198+
self, sync_cursor: Cursor
199+
) -> None:
200+
self._test_cursor_state_after_error(sync_cursor)
177201

178202

179203
class TestAsyncCursor(BaseCursorTestSuit):
@@ -182,6 +206,7 @@ async def async_cursor(
182206
self, session_pool: ydb.aio.QuerySessionPool
183207
) -> AsyncGenerator[Cursor]:
184208
cursor = AsyncCursor(
209+
FakeAsyncConnection(),
185210
session_pool,
186211
ydb.QuerySerializableReadWrite(),
187212
request_settings=ydb.BaseRequestSettings(),
@@ -224,3 +249,11 @@ async def test_cursor_fetch_all_multiple_result_sets(
224249
await greenlet_spawn(
225250
self._test_cursor_fetch_all_multiple_result_sets, async_cursor
226251
)
252+
253+
@pytest.mark.asyncio
254+
async def test_cursor_state_after_error(
255+
self, async_cursor: AsyncCursor
256+
) -> None:
257+
await greenlet_spawn(
258+
self._test_cursor_state_after_error, async_cursor
259+
)

ydb_dbapi/connections.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def __init__(
192192

193193
def cursor(self) -> Cursor:
194194
return self._cursor_cls(
195+
connection=self,
195196
session_pool=self._session_pool,
196197
tx_mode=self._tx_mode,
197198
tx_context=self._tx_context,
@@ -326,6 +327,13 @@ def bulk_upsert(
326327
settings=settings,
327328
)
328329

330+
def _invalidate_session(self) -> None:
331+
if self._tx_context:
332+
self._tx_context = None
333+
if self._session:
334+
self._session_pool.release(self._session)
335+
self._session = None
336+
329337

330338
class AsyncConnection(BaseConnection):
331339
_driver_cls = ydb.aio.Driver
@@ -357,6 +365,7 @@ def __init__(
357365

358366
def cursor(self) -> AsyncCursor:
359367
return self._cursor_cls(
368+
connection=self,
360369
session_pool=self._session_pool,
361370
tx_mode=self._tx_mode,
362371
tx_context=self._tx_context,
@@ -492,6 +501,13 @@ async def bulk_upsert(
492501
settings=settings,
493502
)
494503

504+
async def _invalidate_session(self) -> None:
505+
if self._tx_context:
506+
self._tx_context = None
507+
if self._session:
508+
await self._session_pool.release(self._session)
509+
self._session = None
510+
495511

496512
def connect(*args: tuple, **kwargs: dict) -> Connection:
497513
conn = Connection(*args, **kwargs) # type: ignore

ydb_dbapi/cursors.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

3+
import functools
34
import itertools
45
from collections.abc import AsyncIterator
56
from collections.abc import Generator
67
from collections.abc import Iterator
78
from collections.abc import Sequence
9+
from inspect import iscoroutinefunction
810
from typing import TYPE_CHECKING
911
from typing import Any
12+
from typing import Callable
1013
from typing import Union
1114

1215
import ydb
@@ -20,6 +23,9 @@
2023
from .utils import maybe_get_current_trace_id
2124

2225
if TYPE_CHECKING:
26+
from .connections import AsyncConnection
27+
from .connections import Connection
28+
2329
ParametersType = dict[
2430
str,
2531
Union[
@@ -34,6 +40,34 @@ def _get_column_type(type_obj: Any) -> str:
3440
return str(ydb.convert.type_to_native(type_obj))
3541

3642

43+
def invalidate_cursor_on_ydb_error(func: Callable) -> Callable:
44+
if iscoroutinefunction(func):
45+
46+
@functools.wraps(func)
47+
async def awrapper(
48+
self: AsyncCursor, *args: tuple, **kwargs: dict
49+
) -> Any:
50+
try:
51+
return await func(self, *args, **kwargs)
52+
except ydb.Error:
53+
self._state = CursorStatus.finished
54+
await self._connection._invalidate_session()
55+
raise
56+
57+
return awrapper
58+
59+
@functools.wraps(func)
60+
def wrapper(self: Cursor, *args: tuple, **kwargs: dict) -> Any:
61+
try:
62+
return func(self, *args, **kwargs)
63+
except ydb.Error:
64+
self._state = CursorStatus.finished
65+
self._connection._invalidate_session()
66+
raise
67+
68+
return wrapper
69+
70+
3771
class BufferedCursor:
3872
def __init__(self) -> None:
3973
self.arraysize: int = 1
@@ -154,13 +188,15 @@ def _append_table_path_prefix(self, query: str) -> str:
154188
class Cursor(BufferedCursor):
155189
def __init__(
156190
self,
191+
connection: Connection,
157192
session_pool: ydb.QuerySessionPool,
158193
tx_mode: ydb.BaseQueryTxMode,
159194
request_settings: ydb.BaseRequestSettings,
160195
tx_context: ydb.QueryTxContext | None = None,
161196
table_path_prefix: str = "",
162197
) -> None:
163198
super().__init__()
199+
self._connection = connection
164200
self._session_pool = session_pool
165201
self._tx_mode = tx_mode
166202
self._request_settings = request_settings
@@ -188,6 +224,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
188224
return settings
189225

190226
@handle_ydb_errors
227+
@invalidate_cursor_on_ydb_error
191228
def _execute_generic_query(
192229
self, query: str, parameters: ParametersType | None = None
193230
) -> Iterator[ydb.convert.ResultSet]:
@@ -205,6 +242,7 @@ def callee(
205242
return self._session_pool.retry_operation_sync(callee)
206243

207244
@handle_ydb_errors
245+
@invalidate_cursor_on_ydb_error
208246
def _execute_session_query(
209247
self,
210248
query: str,
@@ -225,6 +263,7 @@ def callee(
225263
return self._session_pool.retry_operation_sync(callee)
226264

227265
@handle_ydb_errors
266+
@invalidate_cursor_on_ydb_error
228267
def _execute_transactional_query(
229268
self,
230269
tx_context: ydb.QueryTxContext,
@@ -283,6 +322,7 @@ def executemany(
283322
self.execute(query, parameters)
284323

285324
@handle_ydb_errors
325+
@invalidate_cursor_on_ydb_error
286326
def nextset(self, replace_current: bool = True) -> bool:
287327
if self._stream is None:
288328
return False
@@ -328,13 +368,15 @@ def __exit__(
328368
class AsyncCursor(BufferedCursor):
329369
def __init__(
330370
self,
371+
connection: AsyncConnection,
331372
session_pool: ydb.aio.QuerySessionPool,
332373
tx_mode: ydb.BaseQueryTxMode,
333374
request_settings: ydb.BaseRequestSettings,
334375
tx_context: ydb.aio.QueryTxContext | None = None,
335376
table_path_prefix: str = "",
336377
) -> None:
337378
super().__init__()
379+
self._connection = connection
338380
self._session_pool = session_pool
339381
self._tx_mode = tx_mode
340382
self._request_settings = request_settings
@@ -362,6 +404,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
362404
return settings
363405

364406
@handle_ydb_errors
407+
@invalidate_cursor_on_ydb_error
365408
async def _execute_generic_query(
366409
self, query: str, parameters: ParametersType | None = None
367410
) -> AsyncIterator[ydb.convert.ResultSet]:
@@ -379,6 +422,7 @@ async def callee(
379422
return await self._session_pool.retry_operation_async(callee)
380423

381424
@handle_ydb_errors
425+
@invalidate_cursor_on_ydb_error
382426
async def _execute_session_query(
383427
self,
384428
query: str,
@@ -399,6 +443,7 @@ async def callee(
399443
return await self._session_pool.retry_operation_async(callee)
400444

401445
@handle_ydb_errors
446+
@invalidate_cursor_on_ydb_error
402447
async def _execute_transactional_query(
403448
self,
404449
tx_context: ydb.aio.QueryTxContext,
@@ -457,6 +502,7 @@ async def executemany(
457502
await self.execute(query, parameters)
458503

459504
@handle_ydb_errors
505+
@invalidate_cursor_on_ydb_error
460506
async def nextset(self, replace_current: bool = True) -> bool:
461507
if self._stream is None:
462508
return False

0 commit comments

Comments
 (0)