Add collate_fn parameter support to DataLoader #48
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This fixes #38
Summary
collate_fn
parameter toDataLoader
for custom batch collation functionalityChanges
collate_fn
parameter with proper type annotationscollate_fn
to batched data from dataset indexingcollate_fn
or defaults to_numpy_collate
tf.data.Dataset.map()
forcollate_fn
applicationCallable
import to typing importstest_collate_fn
function intests.ipynb
with cross-backend validationcore.ipynb
Behavior
The
collate_fn
parameter behaves consistently with PyTorch's DataLoader, allowing users to customize how individual samples are combined into batches. Whencollate_fn=None
, each backend uses its default collation behavior.Test plan
collate_fn
functionality with JAX backendcollate_fn
functionality with PyTorch backendcollate_fn
functionality with TensorFlow backendcollate_fn=None
uses default behaviornbdev_test
to ensure all tests pass