Skip to content

Commit 10011ee

Browse files
committed
custom destination test added
1 parent cf891ff commit 10011ee

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

dlt/destinations/decorators.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,9 @@ def wrapper(
291291
**kwargs,
292292
)
293293
elif kwargs.get("destination_callable"):
294-
destination_name = None if func_or_name == "destination" else func_or_name
295294
destination = Destination.from_reference(
296295
ref="destination",
297-
destination_name=destination_name,
296+
destination_name=func_or_name,
298297
**kwargs,
299298
)
300299
else:

tests/load/pipeline/test_pipelines.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from copy import deepcopy
22
import gzip
33
import os
4-
from typing import Any, Iterator, List, cast
4+
from typing import Any, Iterator, List, cast, Tuple
55
from pathlib import Path
66
import pytest
77

@@ -15,7 +15,8 @@
1515
from dlt.common.schema.schema import Schema
1616
from dlt.common.schema.typing import VERSION_TABLE_NAME, REPLACE_STRATEGIES, TLoaderReplaceStrategy
1717
from dlt.common.schema.utils import new_table
18-
from dlt.common.typing import TDataItem
18+
from dlt.common.schema import TTableSchema
19+
from dlt.common.typing import TDataItem, TDataItems
1920
from dlt.common.utils import uniq_id
2021

2122
from dlt.destinations.exceptions import DestinationUndefinedEntity
@@ -1228,3 +1229,23 @@ def test_data():
12281229
info = pipeline.run(test_data())
12291230
assert_load_info(info)
12301231
assert (Path(TEST_STORAGE_ROOT) / FILE_BUCKET / pipeline.dataset_name / "test_data").exists()
1232+
1233+
# 9. Should automatically infer destination type as 'dlt.destinations.destination' (custom destination implementation),
1234+
# if destination_callable is provided
1235+
calls: List[Tuple[TDataItems, TTableSchema]] = []
1236+
1237+
def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.value, /) -> None:
1238+
nonlocal calls
1239+
assert my_val == "something"
1240+
calls.append((items, table))
1241+
1242+
os.environ["DESTINATION__MY_VAL"] = "something"
1243+
1244+
p = dlt.pipeline(
1245+
"sink_test",
1246+
destination=dlt.destination("custom_name", destination_callable=local_sink_func),
1247+
)
1248+
assert p.destination.destination_name == "custom_name"
1249+
assert p.destination.destination_type == "dlt.destinations.destination"
1250+
p.run([1, 2, 3], table_name="items")
1251+
assert len(calls) == 1

0 commit comments

Comments
 (0)