
Add collate_fn parameter support to DataLoader #48


Open · wants to merge 4 commits into main

Conversation

BirkhoffG
Owner

@BirkhoffG BirkhoffG commented May 26, 2025

This fixes #38

Summary

  • Add a collate_fn parameter to DataLoader for custom batch collation
  • Implement support across all three backends (JAX, PyTorch, TensorFlow)
  • Add comprehensive test coverage and documentation

Changes

  • Core DataLoader: Added collate_fn parameter with proper type annotations
  • JAX Backend: Applies collate_fn to batched data from dataset indexing
  • PyTorch Backend: Uses custom collate_fn or defaults to _numpy_collate
  • TensorFlow Backend: Uses tf.data.Dataset.map() for collate_fn application
  • Type System: Added Callable import to typing imports
  • Testing: Added test_collate_fn function in tests.ipynb with cross-backend validation
  • Documentation: Added usage example in core.ipynb
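The backend code itself is not shown on this page; as a rough sketch of what a `_numpy_collate`-style default might do (the name comes from the changelist above, but the body here is an assumption, not code from this PR), it would stack a list of per-sample arrays into batched arrays:

```python
import numpy as np

def numpy_collate(batch):
    """Sketch of a default collate: stack a list of samples into batched arrays.

    Mirrors what a `_numpy_collate` default might do; the actual
    implementation in the repository may differ.
    """
    sample = batch[0]
    if isinstance(sample, (tuple, list)):
        # Transpose a list of (x, y) samples into stacked (X, Y) arrays.
        return tuple(numpy_collate(field) for field in zip(*batch))
    return np.stack(batch)
```

A custom collate_fn replaces this step entirely, receiving the same raw samples.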

Behavior

The collate_fn parameter behaves consistently with PyTorch's DataLoader, allowing users to customize how individual samples are combined into batches. When collate_fn=None, each backend uses its default collation behavior.
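As a usage sketch: the `jdl.DataLoader` call below is illustrative, and `add_one` is a hypothetical collate function; only the collate logic itself is exercised directly here.

```python
import numpy as np

def add_one(batch):
    # Hypothetical collate_fn: shift features by a constant, pass labels through.
    x, y = batch
    return x + 1, y

# With this PR applied, the function would be passed to the loader, e.g.:
#   dl = jdl.DataLoader(dataset, 'jax', batch_size=4, collate_fn=add_one)
# Applied directly to one batch:
x, y = add_one((np.zeros((4, 2)), np.arange(4)))
```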

Test plan

  • Test collate_fn functionality with JAX backend
  • Test collate_fn functionality with PyTorch backend
  • Test collate_fn functionality with TensorFlow backend
  • Test that collate_fn=None uses default behavior
  • Test custom transformation functions (e.g., adding constants to features)
  • Verify backward compatibility with existing code
  • Run nbdev_test to ensure all tests pass
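The test plan above can be sketched as follows (the real `test_collate_fn` lives in tests.ipynb and may differ; `apply_collate` here is a stand-in for the backend dispatch): assert that `collate_fn=None` leaves the batch untouched while a custom function transforms it.

```python
import numpy as np

def apply_collate(batch, collate_fn=None):
    # collate_fn=None falls back to the default behavior
    # (here: identity, since the batch is already stacked).
    return batch if collate_fn is None else collate_fn(batch)

def test_collate_fn_sketch():
    batch = (np.zeros((2, 3)), np.array([0, 1]))
    # Default: batch passes through unchanged.
    x, y = apply_collate(batch)
    assert (x == 0).all()
    # Custom transformation: add a constant to the features.
    x, y = apply_collate(batch, lambda b: (b[0] + 10, b[1]))
    assert (x == 10).all() and y.tolist() == [0, 1]

test_collate_fn_sketch()
```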

- Add collate_fn parameter to core DataLoader class for custom batch collation
- Implement collate_fn support across all backends (JAX, PyTorch, TensorFlow)
- JAX backend applies collate_fn to batched data from dataset indexing
- PyTorch backend uses custom collate_fn or defaults to _numpy_collate
- TensorFlow backend uses tf.data.Dataset.map() for collate_fn application
- Add Callable import to typing imports for proper type annotations
- Add comprehensive test_collate_fn function in tests.ipynb
- Update BaseDataLoader interface to include collate_fn parameter
- Add documentation example in core.ipynb demonstrating usage

The collate_fn parameter behaves consistently with PyTorch's DataLoader,
allowing users to customize how individual samples are combined into batches.

BirkhoffG added 3 commits May 26, 2025 16:29
- Fix test_collate_fn to properly handle PyTorch list format vs JAX batched format
- Add support for HuggingFace datasets with PyTorch backend (list of dicts)
- Ensure collate_fn works correctly across all backends (JAX, PyTorch, TensorFlow)
- All collate_fn tests now pass for available backends
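With the PyTorch backend, HuggingFace datasets yield a list of per-sample dicts rather than pre-batched arrays, so a collate_fn there sees list-of-dicts input. A minimal sketch of collating that shape (an assumption for illustration, not code from this PR):

```python
import numpy as np

def collate_dicts(samples):
    # Turn a list of per-sample dicts into one dict of stacked arrays, e.g.
    # [{"x": a1, "y": 0}, {"x": a2, "y": 1}] -> {"x": stack([a1, a2]), "y": array([0, 1])}.
    return {key: np.stack([s[key] for s in samples]) for key in samples[0]}
```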
- Add tf_collate_wrapper to handle TensorFlow's argument unpacking behavior
- TensorFlow map() calls functions with unpacked args (features, labels)
- Wrapper packs them back into tuple format expected by collate_fn
- Ensures result is properly unpacked for TensorFlow consumption
- Fixes TypeError in CI when TensorFlow is installed
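The wrapper described above can be sketched without TensorFlow installed (the name `tf_collate_wrapper` comes from the commit message; the packing details are assumptions). `tf.data.Dataset.map()` calls the mapped function with the tuple elements as separate positional arguments, so the wrapper re-packs them before calling collate_fn and returns a tuple that TensorFlow can unpack again:

```python
def tf_collate_wrapper(collate_fn):
    # tf.data.Dataset.map unpacks tuple elements into positional args,
    # e.g. fn(features, labels); the user's collate_fn expects one batch tuple.
    def wrapped(*elements):
        # Re-pack the positional args into the tuple shape collate_fn expects,
        # and return the result as a tuple so TF can unpack it again.
        return tuple(collate_fn(elements))
    return wrapped
```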
Development

Successfully merging this pull request may close these issues.

Support collate_fn for jdl.DataLoader