Skip to content

Commit c61e240

Browse files
committed
Add bulk upsert to connection
1 parent 7b4fac3 commit c61e240

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

tests/test_connections.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,52 @@ def _test_errors(
149149
maybe_await(cur.execute_scheme("DROP TABLE test"))
150150
maybe_await(cur.close())
151151

152+
def _test_bulk_upsert(self, connection: dbapi.Connection):
153+
cursor = connection.cursor()
154+
maybe_await(cursor.execute_scheme(
155+
"""
156+
CREATE TABLE pet (
157+
pet_id INT,
158+
name TEXT NOT NULL,
159+
pet_type TEXT NOT NULL,
160+
birth_date TEXT NOT NULL,
161+
owner TEXT NOT NULL,
162+
PRIMARY KEY (pet_id)
163+
);
164+
"""
165+
))
166+
167+
column_types = (
168+
ydb.BulkUpsertColumns()
169+
.add_column("pet_id", ydb.OptionalType(ydb.PrimitiveType.Int32))
170+
.add_column("name", ydb.PrimitiveType.Utf8)
171+
.add_column("pet_type", ydb.PrimitiveType.Utf8)
172+
.add_column("birth_date", ydb.PrimitiveType.Utf8)
173+
.add_column("owner", ydb.PrimitiveType.Utf8)
174+
)
175+
176+
rows = [
177+
{
178+
"pet_id": 3,
179+
"name": "Lester",
180+
"pet_type": "Hamster",
181+
"birth_date": "2020-06-23",
182+
"owner": "Lily"
183+
},
184+
{
185+
"pet_id": 4,
186+
"name": "Quincy",
187+
"pet_type": "Parrot",
188+
"birth_date": "2013-08-11",
189+
"owner": "Anne"
190+
},
191+
]
192+
193+
maybe_await(connection.bulk_upsert("pet", rows, column_types))
194+
195+
maybe_await(cursor.execute("SELECT * FROM pet"))
196+
assert cursor.rowcount == 2
197+
152198

153199
class TestConnection(BaseDBApiTestSuit):
154200
@pytest.fixture
@@ -191,6 +237,9 @@ def test_cursor_raw_query(self, connection: dbapi.Connection) -> None:
191237
def test_errors(self, connection: dbapi.Connection) -> None:
192238
self._test_errors(connection)
193239

240+
def test_bulk_upsert(self, connection: dbapi.Connection) -> None:
241+
self._test_bulk_upsert(connection)
242+
194243

195244
class TestAsyncConnection(BaseDBApiTestSuit):
196245
@pytest_asyncio.fixture
@@ -244,3 +293,9 @@ async def test_cursor_raw_query(
244293
@pytest.mark.asyncio
245294
async def test_errors(self, connection: dbapi.AsyncConnection) -> None:
246295
await greenlet_spawn(self._test_errors, connection)
296+
297+
@pytest.mark.asyncio
298+
async def test_bulk_upsert(
299+
self, connection: dbapi.AsyncConnection
300+
) -> None:
301+
await greenlet_spawn(self._test_bulk_upsert, connection)

ydb_dbapi/connections.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import posixpath
4+
from collections.abc import Sequence
45
from enum import Enum
56
from typing import NamedTuple
67

@@ -301,6 +302,17 @@ def callee() -> ydb.Directory:
301302
result.extend(self._get_table_names(child_abs_path))
302303
return result
303304

305+
@handle_ydb_errors
306+
def bulk_upsert(
307+
self,
308+
table_name: str,
309+
rows: Sequence,
310+
column_types: ydb.BulkUpsertColumns,
311+
) -> None:
312+
self._driver.table_client.bulk_upsert(
313+
table_name, rows=rows, column_types=column_types
314+
)
315+
304316

305317
class AsyncConnection(BaseConnection):
306318
_driver_cls = ydb.aio.Driver
@@ -446,6 +458,17 @@ async def callee() -> ydb.Directory:
446458
result.extend(await self._get_table_names(child_abs_path))
447459
return result
448460

461+
@handle_ydb_errors
462+
async def bulk_upsert(
463+
self,
464+
table_name: str,
465+
rows: Sequence,
466+
column_types: ydb.BulkUpsertColumns,
467+
) -> None:
468+
await self._driver.table_client.bulk_upsert(
469+
table_name, rows=rows, column_types=column_types
470+
)
471+
449472

450473
def connect(*args: tuple, **kwargs: dict) -> Connection:
451474
conn = Connection(*args, **kwargs) # type: ignore

0 commit comments

Comments
 (0)