From 5870e24a02f4184445376e942f976f2087290a65 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 31 Oct 2024 11:04:32 +0300 Subject: [PATCH 1/2] Fix tx modes --- tests/test_connections.py | 9 +++----- ydb_dbapi/connections.py | 46 ++++++++++++++++++++------------------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/tests/test_connections.py b/tests/test_connections.py index 0488f0f..a138f09 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -30,7 +30,6 @@ def _test_isolation_level_read_only( cursor = connection.cursor() with suppress(dbapi.DatabaseError): maybe_await(cursor.execute("DROP TABLE foo")) - cursor = connection.cursor() maybe_await(cursor.execute( "CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))" @@ -38,21 +37,17 @@ def _test_isolation_level_read_only( connection.set_isolation_level(isolation_level) cursor = connection.cursor() - query = "UPSERT INTO foo(id) VALUES (1)" if read_only: with pytest.raises(dbapi.DatabaseError): maybe_await(cursor.execute(query)) - else: maybe_await(cursor.execute(query)) maybe_await(connection.rollback()) connection.set_isolation_level("AUTOCOMMIT") - cursor = connection.cursor() - maybe_await(cursor.execute("DROP TABLE foo")) def _test_connection(self, connection: dbapi.Connection) -> None: @@ -211,7 +206,9 @@ def connect() -> dbapi.AsyncConnection: try: yield conn finally: - await greenlet_spawn(conn.close) + def close() -> None: + maybe_await(conn.close()) + await greenlet_spawn(close) @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index b7c353e..e0033fb 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -30,26 +30,24 @@ class IsolationLevel(str, Enum): class _IsolationSettings(NamedTuple): - ydb_mode: ydb.BaseQueryTxMode + ydb_mode: ydb.BaseQueryTxMode | None interactive: bool _ydb_isolation_settings_map = { - IsolationLevel.AUTOCOMMIT: _IsolationSettings( - ydb.QuerySerializableReadWrite(), interactive=False - ), + IsolationLevel.AUTOCOMMIT: _IsolationSettings(None, interactive=False), IsolationLevel.SERIALIZABLE: _IsolationSettings( ydb.QuerySerializableReadWrite(), interactive=True ), IsolationLevel.ONLINE_READONLY: _IsolationSettings( - ydb.QueryOnlineReadOnly(), interactive=True + ydb.QueryOnlineReadOnly(), interactive=False ), IsolationLevel.ONLINE_READONLY_INCONSISTENT: _IsolationSettings( ydb.QueryOnlineReadOnly().with_allow_inconsistent_reads(), - interactive=True, + interactive=False, ), IsolationLevel.STALE_READONLY: _IsolationSettings( - ydb.QueryStaleReadOnly(), interactive=True + ydb.QueryStaleReadOnly(), interactive=False ), IsolationLevel.SNAPSHOT_READONLY: _IsolationSettings( ydb.QuerySnapshotReadOnly(), interactive=True @@ -78,10 +76,11 @@ def __init__( self.connection_kwargs: dict = kwargs - self._tx_mode: ydb.BaseQueryTxMode = ydb.QuerySerializableReadWrite() + self._shared_session_pool: bool = False + self._tx_context: TxContext | AsyncTxContext | None = None + self._tx_mode: ydb.BaseQueryTxMode | None = None self.interactive_transaction: bool = False - self._shared_session_pool: bool = False if ydb_session_pool is not None: self._shared_session_pool = True @@ -99,21 +98,24 @@ def __init__( self._session: ydb.QuerySession | ydb.aio.QuerySession | None = None def set_isolation_level(self, isolation_level: IsolationLevel) -> None: - ydb_isolation_settings = _ydb_isolation_settings_map[isolation_level] if self._tx_context and self._tx_context.tx_id: raise InternalError( "Failed to set transaction mode: transaction is already began" ) + + ydb_isolation_settings = _ydb_isolation_settings_map[isolation_level] + + self._tx_context = None self._tx_mode = ydb_isolation_settings.ydb_mode self.interactive_transaction = ydb_isolation_settings.interactive def get_isolation_level(self) -> str: - if self._tx_mode.name == ydb.QuerySerializableReadWrite().name: - if self.interactive_transaction: - return IsolationLevel.SERIALIZABLE + if self._tx_mode is None: return IsolationLevel.AUTOCOMMIT + if self._tx_mode.name == ydb.QuerySerializableReadWrite().name: + return IsolationLevel.SERIALIZABLE if self._tx_mode.name == ydb.QueryOnlineReadOnly().name: - if self._tx_mode.settings.allow_inconsistent_reads: + if self._tx_mode.allow_inconsistent_reads: return IsolationLevel.ONLINE_READONLY_INCONSISTENT return IsolationLevel.ONLINE_READONLY if self._tx_mode.name == ydb.QueryStaleReadOnly().name: @@ -123,6 +125,12 @@ def get_isolation_level(self) -> str: msg = f"{self._tx_mode.name} is not supported" raise NotSupportedError(msg) + def _maybe_init_tx( + self, session: ydb.QuerySession | ydb.aio.QuerySession + ) -> None: + if self._tx_context is None and self._tx_mode is not None: + self._tx_context = session.transaction(self._tx_mode) + class Connection(BaseConnection): _driver_cls = ydb.Driver @@ -154,10 +162,7 @@ def cursor(self) -> Cursor: if self._session is None: raise RuntimeError("Connection is not ready, use wait_ready.") - if self.interactive_transaction: - self._tx_context = self._session.transaction(self._tx_mode) - else: - self._tx_context = None + self._maybe_init_tx(self._session) self._current_cursor = self._cursor_cls( session=self._session, @@ -281,10 +286,7 @@ def cursor(self) -> AsyncCursor: if self._session is None: raise RuntimeError("Connection is not ready, use wait_ready.") - if self.interactive_transaction: - self._tx_context = self._session.transaction(self._tx_mode) - else: - self._tx_context = None + self._maybe_init_tx(self._session) self._current_cursor = self._cursor_cls( session=self._session, From d6850b2df8bff80753aeb97972e10b68495f84ad Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 31 Oct 2024 18:11:46 +0300 Subject: [PATCH 2/2] review fixes --- ydb_dbapi/connections.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index 0d3473f..72db29b 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -30,7 +30,7 @@ class IsolationLevel(str, Enum): class _IsolationSettings(NamedTuple): - ydb_mode: ydb.BaseQueryTxMode | None + ydb_mode: ydb.BaseQueryTxMode interactive: bool @@ -187,16 +187,19 @@ def wait_ready(self, timeout: int = 10) -> None: self._session = self._session_pool.acquire() + @handle_ydb_errors def commit(self) -> None: if self._tx_context and self._tx_context.tx_id: self._tx_context.commit() self._tx_context = None + @handle_ydb_errors def rollback(self) -> None: if self._tx_context and self._tx_context.tx_id: self._tx_context.rollback() self._tx_context = None + @handle_ydb_errors def close(self) -> None: self.rollback() @@ -311,16 +314,19 @@ async def wait_ready(self, timeout: int = 10) -> None: self._session = await self._session_pool.acquire() + @handle_ydb_errors async def commit(self) -> None: if self._tx_context and self._tx_context.tx_id: await self._tx_context.commit() self._tx_context = None + @handle_ydb_errors async def rollback(self) -> None: if self._tx_context and self._tx_context.tx_id: await self._tx_context.rollback() self._tx_context = None + @handle_ydb_errors async def close(self) -> None: await self.rollback()