From df67b77507b018c850da4134b72374f0486259ff Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 21 May 2025 12:08:06 +0300 Subject: [PATCH] Ability to get view names --- examples/basic_example.py | 19 ++++++++++++++++ poetry.lock | 10 ++++----- pyproject.toml | 3 ++- tests/test_connections.py | 46 +++++++++++++++++++++++++++++++++++++++ ydb_dbapi/connections.py | 40 +++++++++++++++++++++++++++------- 5 files changed, 104 insertions(+), 14 deletions(-) create mode 100644 examples/basic_example.py diff --git a/examples/basic_example.py b/examples/basic_example.py new file mode 100644 index 0000000..ce96a6d --- /dev/null +++ b/examples/basic_example.py @@ -0,0 +1,19 @@ +import ydb_dbapi as dbapi + + +def main() -> None: + connection = dbapi.connect( + host="localhost", + port=2136, + database="/local", + ) + + print(f"Existing tables: {connection.get_table_names()}") + print(f"Existing views: {connection.get_view_names()}") + # TODO: fill example + + connection.close() + + +if __name__ == "__main__": + main() diff --git a/poetry.lock b/poetry.lock index 14bd356..8f3f385 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1699,21 +1699,21 @@ propcache = ">=0.2.0" [[package]] name = "ydb" -version = "3.18.13" +version = "3.21.2" description = "YDB Python SDK" optional = false python-versions = "*" groups = ["main"] files = [ - {file = "ydb-3.18.13-py2.py3-none-any.whl", hash = "sha256:a4309520da740c8aa630a3a2ca0d6d477bdd459f95cc52930742af31c74f12f8"}, - {file = "ydb-3.18.13.tar.gz", hash = "sha256:89ddcb636d6689e1143592a44a5c27b2b403b41b3ab5aeff8fc83b6bb518fd85"}, + {file = "ydb-3.21.2-py2.py3-none-any.whl", hash = "sha256:5b33ecf936ac61a0641785a066f06c2cd7e36d0499e742538a036d3ec694f1df"}, + {file = "ydb-3.21.2.tar.gz", hash = "sha256:03bbd87b449a12dfdd6fe3e599be5fc41ec5180e2263135b40570b9f6beacba0"}, ] [package.dependencies] aiohttp = "<4" grpcio = ">=1.42.0" packaging = "*" -protobuf = ">=3.13.0,<5.0.0" +protobuf = ">=3.13.0,<6.0.0" [package.extras] yc = ["yandexcloud"] @@ -1721,4 +1721,4 @@ yc = ["yandexcloud"] [metadata] lock-version = "2.1" python-versions = "^3.8" -content-hash = "fe72e73d60e061b1eb8603a09d0a855448432edb6bf4e667cef3beea8de233a1" +content-hash = "4dfe32e606f908c14ff2d6427c9fb38e1a7860d8fa9034c0bb97f46e1ac3a7cf" diff --git a/pyproject.toml b/pyproject.toml index cb0d1c3..8569a3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ repository = "https://github.com/ydb-platform/ydb-python-dbapi/" [tool.poetry.dependencies] python = "^3.8" -ydb = "^3.18.13" +ydb = "^3.18.16" [tool.poetry.group.dev.dependencies] pre-commit = "^3.5.0" @@ -78,6 +78,7 @@ force-single-line = true [tool.ruff.lint.per-file-ignores] "**/test_*.py" = ["S", "SLF", "ANN201", "ARG", "PLR2004", "PT012"] +"examples/*.py" = ["T201", "INP001"] "conftest.py" = ["S", "ARG001"] "__init__.py" = ["F401", "F403"] diff --git a/tests/test_connections.py b/tests/test_connections.py index 8470985..cbfb393 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -239,6 +239,41 @@ def _test_error_with_interactive_tx( maybe_await(cur.close()) maybe_await(connection.rollback()) + def _test_get_view_names( + self, + connection: dbapi.Connection, + ) -> None: + cur = connection.cursor() + + maybe_await( + cur.execute_scheme( + """ + DROP VIEW if exists test_view; + """ + ) + ) + + res = maybe_await(connection.get_view_names()) + + assert len(res) == 0 + + maybe_await( + cur.execute_scheme( + """ + CREATE VIEW test_view WITH (security_invoker = TRUE) as ( + select 1 as res + ); + """ + ) + ) + + res = maybe_await(connection.get_view_names()) + + assert len(res) == 1 + assert res[0] == "test_view" + + maybe_await(cur.close()) + class TestConnection(BaseDBApiTestSuit): @pytest.fixture @@ -289,6 +324,11 @@ def test_errors_with_interactive_tx( ) -> None: self._test_error_with_interactive_tx(connection) + def test_get_view_names( + self, connection: dbapi.Connection + ) -> None: + self._test_get_view_names(connection) + class TestAsyncConnection(BaseDBApiTestSuit): @pytest_asyncio.fixture @@ -358,3 +398,9 @@ async def test_errors_with_interactive_tx( self, connection: dbapi.AsyncConnection ) -> None: await greenlet_spawn(self._test_error_with_interactive_tx, connection) + + @pytest.mark.asyncio + async def test_get_view_names( + self, connection: dbapi.AsyncConnection + ) -> None: + await greenlet_spawn(self._test_get_view_names, connection) diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index a09250e..8030b17 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -277,7 +277,13 @@ def check_exists(self, table_path: str) -> bool: @handle_ydb_errors def get_table_names(self) -> list[str]: abs_dir_path = posixpath.join(self.database, self.table_path_prefix) - names = self._get_table_names(abs_dir_path) + names = self._get_entity_names(abs_dir_path, ydb.SchemeEntryType.TABLE) + return [posixpath.relpath(path, abs_dir_path) for path in names] + + @handle_ydb_errors + def get_view_names(self) -> list[str]: + abs_dir_path = posixpath.join(self.database, self.table_path_prefix) + names = self._get_entity_names(abs_dir_path, ydb.SchemeEntryType.VIEW) return [posixpath.relpath(path, abs_dir_path) for path in names] def _check_path_exists(self, table_path: str) -> bool: @@ -295,7 +301,9 @@ def callee() -> None: else: return True - def _get_table_names(self, abs_dir_path: str) -> list[str]: + def _get_entity_names( + self, abs_dir_path: str, etype: ydb.SchemeEntryType + ) -> list[str]: settings = self._get_request_settings() def callee() -> ydb.Directory: @@ -308,10 +316,10 @@ def callee() -> ydb.Directory: result = [] for child in directory.children: child_abs_path = posixpath.join(abs_dir_path, child.name) - if child.is_table(): + if child.type == etype: result.append(child_abs_path) elif child.is_directory() and not child.name.startswith("."): - result.extend(self._get_table_names(child_abs_path)) + result.extend(self._get_entity_names(child_abs_path, etype)) return result @handle_ydb_errors @@ -452,7 +460,19 @@ async def check_exists(self, table_path: str) -> bool: @handle_ydb_errors async def get_table_names(self) -> list[str]: abs_dir_path = posixpath.join(self.database, self.table_path_prefix) - names = await self._get_table_names(abs_dir_path) + names = await self._get_entity_names( + abs_dir_path, + ydb.SchemeEntryType.TABLE, + ) + return [posixpath.relpath(path, abs_dir_path) for path in names] + + @handle_ydb_errors + async def get_view_names(self) -> list[str]: + abs_dir_path = posixpath.join(self.database, self.table_path_prefix) + names = await self._get_entity_names( + abs_dir_path, + ydb.SchemeEntryType.VIEW, + ) return [posixpath.relpath(path, abs_dir_path) for path in names] async def _check_path_exists(self, table_path: str) -> bool: @@ -471,7 +491,9 @@ async def callee() -> None: else: return True - async def _get_table_names(self, abs_dir_path: str) -> list[str]: + async def _get_entity_names( + self, abs_dir_path: str, etype: ydb.SchemeEntryType + ) -> list[str]: settings = self._get_request_settings() async def callee() -> ydb.Directory: @@ -484,10 +506,12 @@ async def callee() -> ydb.Directory: result = [] for child in directory.children: child_abs_path = posixpath.join(abs_dir_path, child.name) - if child.is_table(): + if child.type == etype: result.append(child_abs_path) elif child.is_directory() and not child.name.startswith("."): - result.extend(await self._get_table_names(child_abs_path)) + result.extend( + await self._get_entity_names(child_abs_path, etype) + ) return result @handle_ydb_errors