From 56d41513e95d800f1c9a72c405c30abe5565a2a5 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 7 Mar 2025 19:07:50 +0300 Subject: [PATCH] Fix pending query issue --- tests/test_connections.py | 99 +++++++++++--------- tests/test_cursors.py | 9 +- ydb_dbapi/connections.py | 4 +- ydb_dbapi/cursors.py | 191 ++++++++++++++++++-------------------- 4 files changed, 152 insertions(+), 151 deletions(-) diff --git a/tests/test_connections.py b/tests/test_connections.py index ba9420d..8470985 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -29,9 +29,11 @@ def _test_isolation_level_read_only( cursor = connection.cursor() with suppress(dbapi.DatabaseError): maybe_await(cursor.execute_scheme("DROP TABLE foo")) - maybe_await(cursor.execute_scheme( - "CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))" - )) + maybe_await( + cursor.execute_scheme( + "CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))" + ) + ) connection.set_isolation_level(isolation_level) cursor = connection.cursor() @@ -60,9 +62,11 @@ def _test_connection(self, connection: dbapi.Connection) -> None: with pytest.raises(dbapi.ProgrammingError): maybe_await(connection.describe("/local/foo")) - maybe_await(cur.execute_scheme( - "CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))" - )) + maybe_await( + cur.execute_scheme( + "CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))" + ) + ) assert maybe_await(connection.check_exists("/local/foo")) @@ -84,26 +88,28 @@ def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None: "CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))" )) - maybe_await(cur.execute( - """ + maybe_await( + cur.execute( + """ DECLARE $data AS List>; INSERT INTO test SELECT id, text FROM AS_TABLE($data); """, - { - "$data": ydb.TypedValue( - [ - {"id": 17, "text": "seventeen"}, - {"id": 21, "text": "twenty one"}, - ], - ydb.ListType( - ydb.StructType() - .add_member("id", ydb.PrimitiveType.Int64) - .add_member("text", ydb.PrimitiveType.Utf8) - ), - ) - }, - )) + { + "$data": ydb.TypedValue( + [ + {"id": 17, "text": "seventeen"}, + {"id": 21, "text": "twenty one"}, + ], + ydb.ListType( + ydb.StructType() + .add_member("id", ydb.PrimitiveType.Int64) + .add_member("text", ydb.PrimitiveType.Utf8) + ), + ) + }, + ) + ) maybe_await(cur.execute_scheme("DROP TABLE test")) @@ -112,13 +118,15 @@ def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None: def _test_errors( self, connection: dbapi.Connection, - connect_method: callable = dbapi.connect + connect_method: callable = dbapi.connect, ) -> None: with pytest.raises(dbapi.InterfaceError): - maybe_await(connect_method( - "localhost:2136", # type: ignore - database="/local666", # type: ignore - )) + maybe_await( + connect_method( + "localhost:2136", # type: ignore + database="/local666", # type: ignore + ) + ) cur = connection.cursor() @@ -137,9 +145,9 @@ def _test_errors( with pytest.raises(dbapi.ProgrammingError): maybe_await(cur.execute("SELECT * FROM test")) - maybe_await(cur.execute_scheme( - "CREATE TABLE test(id Int64, PRIMARY KEY (id))" - )) + maybe_await( + cur.execute_scheme("CREATE TABLE test(id Int64, PRIMARY KEY (id))") + ) maybe_await(cur.execute("INSERT INTO test(id) VALUES(1)")) @@ -154,8 +162,9 @@ def _test_bulk_upsert(self, connection: dbapi.Connection) -> None: with suppress(dbapi.DatabaseError): maybe_await(cursor.execute_scheme("DROP TABLE pet")) - maybe_await(cursor.execute_scheme( - """ + maybe_await( + cursor.execute_scheme( + """ CREATE TABLE pet ( pet_id INT, name TEXT NOT NULL, @@ -165,7 +174,8 @@ def _test_bulk_upsert(self, connection: dbapi.Connection) -> None: PRIMARY KEY (pet_id) ); """ - )) + ) + ) column_types = ( ydb.BulkUpsertColumns() @@ -182,14 +192,14 @@ def _test_bulk_upsert(self, connection: dbapi.Connection) -> None: "name": "Lester", "pet_type": "Hamster", "birth_date": "2020-06-23", - "owner": "Lily" + "owner": "Lily", }, { "pet_id": 4, "name": "Quincy", "pet_type": "Parrot", "birth_date": "2013-08-11", - "owner": "Anne" + "owner": "Anne", }, ] @@ -204,10 +214,10 @@ def _test_error_with_interactive_tx( self, connection: dbapi.Connection, ) -> None: - cur = connection.cursor() - maybe_await(cur.execute_scheme( - """ + maybe_await( + cur.execute_scheme( + """ DROP TABLE IF EXISTS test; CREATE TABLE test ( id Int64 NOT NULL, @@ -215,7 +225,8 @@ def _test_error_with_interactive_tx( PRIMARY KEY(id) ) """ - )) + ) + ) connection.set_isolation_level(dbapi.IsolationLevel.SERIALIZABLE) maybe_await(connection.begin()) @@ -274,8 +285,8 @@ def test_bulk_upsert(self, connection: dbapi.Connection) -> None: self._test_bulk_upsert(connection) def test_errors_with_interactive_tx( - self, connection: dbapi.Connection - ) -> None: + self, connection: dbapi.Connection + ) -> None: self._test_error_with_interactive_tx(connection) @@ -291,8 +302,10 @@ def connect() -> dbapi.AsyncConnection: try: yield conn finally: + def close() -> None: maybe_await(conn.close()) + await greenlet_spawn(close) @pytest.mark.asyncio @@ -315,7 +328,9 @@ async def test_isolation_level_read_only( ) -> None: await greenlet_spawn( self._test_isolation_level_read_only, - connection, isolation_level, read_only + connection, + isolation_level, + read_only, ) @pytest.mark.asyncio diff --git a/tests/test_cursors.py b/tests/test_cursors.py index d941a67..f951121 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -161,7 +161,6 @@ class TestCursor(BaseCursorTestSuit): def sync_cursor( self, session_pool_sync: ydb.QuerySessionPool ) -> Generator[Cursor]: - cursor = Cursor( FakeSyncConnection(), session_pool_sync, @@ -195,9 +194,7 @@ def test_cursor_fetch_all_multiple_result_sets( ) -> None: self._test_cursor_fetch_all_multiple_result_sets(sync_cursor) - def test_cursor_state_after_error( - self, sync_cursor: Cursor - ) -> None: + def test_cursor_state_after_error(self, sync_cursor: Cursor) -> None: self._test_cursor_state_after_error(sync_cursor) @@ -255,6 +252,4 @@ async def test_cursor_fetch_all_multiple_result_sets( async def test_cursor_state_after_error( self, async_cursor: AsyncCursor ) -> None: - await greenlet_spawn( - self._test_cursor_state_after_error, async_cursor - ) + await greenlet_spawn(self._test_cursor_state_after_error, async_cursor) diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index 06f4bbe..a09250e 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -102,7 +102,9 @@ def __init__( database=self.database, credentials=self.credentials, query_client_settings=self._get_client_settings(), - root_certificates=ydb.load_ydb_root_certificate(root_certificates_path), + root_certificates=ydb.load_ydb_root_certificate( + root_certificates_path + ), ) self._driver = self._driver_cls(driver_config) self._session_pool = self._pool_cls(self._driver, size=5) diff --git a/ydb_dbapi/cursors.py b/ydb_dbapi/cursors.py index 66b7ac8..41706f9 100644 --- a/ydb_dbapi/cursors.py +++ b/ydb_dbapi/cursors.py @@ -144,6 +144,10 @@ def _update_description(self, result_set: ydb.convert.ResultSet) -> None: for col in result_set.columns ] + def _fill_buffer(self, result_set_list: list) -> None: + for result_set in result_set_list: + self._update_result_set(result_set, replace_current=False) + def _raise_if_running(self) -> None: if self._state == CursorStatus.running: raise ProgrammingError( @@ -164,6 +168,9 @@ def is_closed(self) -> bool: def _begin_query(self) -> None: self._state = CursorStatus.running + def _finish_query(self) -> None: + self._state = CursorStatus.finished + def _fetchone_from_buffer(self) -> tuple | None: self._raise_if_closed() return next(self._rows or iter([]), None) @@ -223,20 +230,27 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings: return settings + def _materialize( + self, stream: Iterator[ydb.convert.ResultSet] + ) -> list[ydb.convert.ResultSet]: + return list(stream) + @handle_ydb_errors @invalidate_cursor_on_ydb_error def _execute_generic_query( self, query: str, parameters: ParametersType | None = None - ) -> Iterator[ydb.convert.ResultSet]: + ) -> list[ydb.convert.ResultSet]: settings = self._get_request_settings() def callee( session: ydb.QuerySession, - ) -> Iterator[ydb.convert.ResultSet]: - return session.execute( - query=query, - parameters=parameters, - settings=settings, + ) -> list[ydb.convert.ResultSet]: + return self._materialize( + session.execute( + query=query, + parameters=parameters, + settings=settings, + ) ) return self._session_pool.retry_operation_sync(callee) @@ -247,17 +261,19 @@ def _execute_session_query( self, query: str, parameters: ParametersType | None = None, - ) -> Iterator[ydb.convert.ResultSet]: + ) -> list[ydb.convert.ResultSet]: settings = self._get_request_settings() def callee( session: ydb.QuerySession, - ) -> Iterator[ydb.convert.ResultSet]: - return session.transaction(self._tx_mode).execute( - query=query, - parameters=parameters, - commit_tx=True, - settings=settings, + ) -> list[ydb.convert.ResultSet]: + return self._materialize( + session.transaction(self._tx_mode).execute( + query=query, + parameters=parameters, + commit_tx=True, + settings=settings, + ) ) return self._session_pool.retry_operation_sync(callee) @@ -269,13 +285,15 @@ def _execute_transactional_query( tx_context: ydb.QueryTxContext, query: str, parameters: ParametersType | None = None, - ) -> Iterator[ydb.convert.ResultSet]: + ) -> list[ydb.convert.ResultSet]: settings = self._get_request_settings() - return tx_context.execute( - query=query, - parameters=parameters, - commit_tx=False, - settings=settings, + return self._materialize( + tx_context.execute( + query=query, + parameters=parameters, + commit_tx=False, + settings=settings, + ) ) def execute_scheme( @@ -286,12 +304,13 @@ def execute_scheme( self._raise_if_closed() query = self._append_table_path_prefix(query) + self._begin_query() - self._stream = self._execute_generic_query( + result_list = self._execute_generic_query( query=query, parameters=parameters ) - self._begin_query() - self._scroll_stream(replace_current=False) + self._fill_buffer(result_list) + self._finish_query() def execute( self, @@ -302,18 +321,19 @@ def execute( self._raise_if_running() query = self._append_table_path_prefix(query) + self._begin_query() if self._tx_context is not None: - self._stream = self._execute_transactional_query( + result_list = self._execute_transactional_query( tx_context=self._tx_context, query=query, parameters=parameters ) else: - self._stream = self._execute_session_query( + result_list = self._execute_session_query( query=query, parameters=parameters ) - self._begin_query() - self._scroll_stream(replace_current=False) + self._fill_buffer(result_list) + self._finish_query() def executemany( self, query: str, seq_of_parameters: Sequence[ParametersType] @@ -321,30 +341,8 @@ def executemany( for parameters in seq_of_parameters: self.execute(query, parameters) - @handle_ydb_errors - @invalidate_cursor_on_ydb_error - def nextset(self, replace_current: bool = True) -> bool: - if self._stream is None: - return False - try: - result_set = self._stream.__next__() - self._update_result_set(result_set, replace_current) - except (StopIteration, StopAsyncIteration, RuntimeError): - self._state = CursorStatus.finished - return False - except ydb.Error: - self._state = CursorStatus.finished - raise - return True - - def _scroll_stream(self, replace_current: bool = True) -> None: - self._raise_if_closed() - - next_set_available = True - while next_set_available: - next_set_available = self.nextset(replace_current) - - self._state = CursorStatus.finished + def nextset(self) -> bool: + return False def close(self) -> None: if self._state == CursorStatus.closed: @@ -402,20 +400,27 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings: return settings + async def _materialize( + self, stream: AsyncIterator[ydb.convert.ResultSet] + ) -> list[ydb.convert.ResultSet]: + return [result_set async for result_set in stream] + @handle_ydb_errors @invalidate_cursor_on_ydb_error async def _execute_generic_query( self, query: str, parameters: ParametersType | None = None - ) -> AsyncIterator[ydb.convert.ResultSet]: + ) -> list[ydb.convert.ResultSet]: settings = self._get_request_settings() async def callee( session: ydb.aio.QuerySession, - ) -> AsyncIterator[ydb.convert.ResultSet]: - return await session.execute( - query=query, - parameters=parameters, - settings=settings, + ) -> list[ydb.convert.ResultSet]: + return await self._materialize( + await session.execute( + query=query, + parameters=parameters, + settings=settings, + ) ) return await self._session_pool.retry_operation_async(callee) @@ -426,17 +431,19 @@ async def _execute_session_query( self, query: str, parameters: ParametersType | None = None, - ) -> AsyncIterator[ydb.convert.ResultSet]: + ) -> list[ydb.convert.ResultSet]: settings = self._get_request_settings() async def callee( session: ydb.aio.QuerySession, - ) -> AsyncIterator[ydb.convert.ResultSet]: - return await session.transaction(self._tx_mode).execute( - query=query, - parameters=parameters, - commit_tx=True, - settings=settings, + ) -> list[ydb.convert.ResultSet]: + return await self._materialize( + await session.transaction(self._tx_mode).execute( + query=query, + parameters=parameters, + commit_tx=True, + settings=settings, + ) ) return await self._session_pool.retry_operation_async(callee) @@ -448,13 +455,15 @@ async def _execute_transactional_query( tx_context: ydb.aio.QueryTxContext, query: str, parameters: ParametersType | None = None, - ) -> AsyncIterator[ydb.convert.ResultSet]: + ) -> list[ydb.convert.ResultSet]: settings = self._get_request_settings() - return await tx_context.execute( - query=query, - parameters=parameters, - commit_tx=False, - settings=settings, + return await self._materialize( + await tx_context.execute( + query=query, + parameters=parameters, + commit_tx=False, + settings=settings, + ) ) async def execute_scheme( @@ -465,12 +474,13 @@ async def execute_scheme( self._raise_if_closed() query = self._append_table_path_prefix(query) + self._begin_query() - self._stream = await self._execute_generic_query( + result_list = await self._execute_generic_query( query=query, parameters=parameters ) - self._begin_query() - await self._scroll_stream(replace_current=False) + self._fill_buffer(result_list) + self._finish_query() async def execute( self, @@ -482,17 +492,19 @@ async def execute( query = self._append_table_path_prefix(query) + self._begin_query() + if self._tx_context is not None: - self._stream = await self._execute_transactional_query( + result_list = await self._execute_transactional_query( tx_context=self._tx_context, query=query, parameters=parameters ) else: - self._stream = await self._execute_session_query( + result_list = await self._execute_session_query( query=query, parameters=parameters ) - self._begin_query() - await self._scroll_stream(replace_current=False) + self._fill_buffer(result_list) + self._finish_query() async def executemany( self, query: str, seq_of_parameters: Sequence[ParametersType] @@ -500,31 +512,8 @@ async def executemany( for parameters in seq_of_parameters: await self.execute(query, parameters) - @handle_ydb_errors - @invalidate_cursor_on_ydb_error - async def nextset(self, replace_current: bool = True) -> bool: - if self._stream is None: - return False - try: - result_set = await self._stream.__anext__() - self._update_result_set(result_set, replace_current) - except (StopIteration, StopAsyncIteration, RuntimeError): - self._stream = None - self._state = CursorStatus.finished - return False - except ydb.Error: - self._state = CursorStatus.finished - raise - return True - - async def _scroll_stream(self, replace_current: bool = True) -> None: - self._raise_if_closed() - - next_set_available = True - while next_set_available: - next_set_available = await self.nextset(replace_current) - - self._state = CursorStatus.finished + async def nextset(self) -> bool: + return False def close(self) -> None: if self._state == CursorStatus.closed: