Skip to content

Commit 589f603

Browse files
committed
Fix condition in get_collate_for_dataset
1 parent bf29472 commit 589f603

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

pytorch_toolbelt/utils/torch_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,8 @@ def get_collate_for_dataset(dataset: Union[Dataset, ConcatDataset]) -> Callable:
347347
collate_fn = default_collate
348348

349349
if hasattr(dataset, "get_collate_fn"):
350-
collate_fn = dataset.get_collate_fn()
351-
352-
if isinstance(dataset, ConcatDataset):
350+
return dataset.get_collate_fn()
351+
elif isinstance(dataset, ConcatDataset):
353352
collates = set(get_collate_for_dataset(ds) for ds in dataset.datasets)
354353
if len(collates) != 1:
355354
raise RuntimeError(

0 commit comments

Comments
 (0)