Add collate_fn parameter support to DataLoader #48

Open
wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions jax_dataloader/_modidx.py
@@ -109,6 +109,7 @@
'jax_dataloader.loaders.torch.to_torch_dataset': ( 'loader.torch.html#to_torch_dataset',
'jax_dataloader/loaders/torch.py')},
'jax_dataloader.tests': { 'jax_dataloader.tests.get_batch': ('tests.html#get_batch', 'jax_dataloader/tests.py'),
'jax_dataloader.tests.test_collate_fn': ('tests.html#test_collate_fn', 'jax_dataloader/tests.py'),
'jax_dataloader.tests.test_dataloader': ('tests.html#test_dataloader', 'jax_dataloader/tests.py'),
'jax_dataloader.tests.test_no_shuffle': ('tests.html#test_no_shuffle', 'jax_dataloader/tests.py'),
'jax_dataloader.tests.test_no_shuffle_drop_last': ( 'tests.html#test_no_shuffle_drop_last',
2 changes: 2 additions & 0 deletions jax_dataloader/core.py
@@ -98,6 +98,7 @@ def __init__(
shuffle: bool = False, # If true, dataloader reshuffles every epoch
drop_last: bool = False, # If true, drop the last incomplete batch
generator: Optional[GeneratorType] = None, # Random seed generator
collate_fn: Optional[Callable] = None, # Function to collate samples into batches
**kwargs
):
dl_cls = _dispatch_dataloader(backend)
@@ -107,6 +108,7 @@
shuffle=shuffle,
drop_last=drop_last,
generator=generator,
collate_fn=collate_fn,
**kwargs
)

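A minimal usage sketch of the new parameter through the top-level `DataLoader`, mirroring the notebook test further down; the dataset values and the name `add_one_collate` are illustrative:

import numpy as np
import jax_dataloader as jdl

# With the jax backend, collate_fn receives an already-batched (X, y) tuple
def add_one_collate(batch):
    X, y = batch
    return X + 1.0, y  # transform features, leave labels unchanged

ds = jdl.ArrayDataset(np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 1, 2]))
dl = jdl.DataLoader(ds, 'jax', batch_size=2, collate_fn=add_one_collate)
X, y = next(iter(dl))  # X == [[2., 3.], [4., 5.]], y == [0, 1]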
1 change: 1 addition & 0 deletions jax_dataloader/imports.py
@@ -12,6 +12,7 @@
Literal,
Union,
Annotated,
Callable,
)
import jax
from jax import vmap, grad, jit, numpy as jnp, random as jrand
1 change: 1 addition & 0 deletions jax_dataloader/loaders/base.py
@@ -21,6 +21,7 @@ def __init__(
num_workers: int = 0, # how many subprocesses to use for data loading.
drop_last: bool = False,
generator: Optional[GeneratorType] = None,
collate_fn: Optional[Callable] = None, # function to collate samples into batches
**kwargs
):
pass
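For orientation, a hypothetical subclass shows how a backend is expected to thread the new argument through the interface; this is an illustrative sketch over a sliceable dataset, not a loader shipped by the package:

from typing import Callable, Optional

class ToyDataLoader(BaseDataLoader):
    """Illustrative sequential loader that applies collate_fn to each sliced batch."""
    def __init__(self, dataset, batch_size: int = 1, collate_fn: Optional[Callable] = None, **kwargs):
        self.dataset, self.batch_size, self.collate_fn = dataset, batch_size, collate_fn
    def __iter__(self):
        for i in range(0, len(self.dataset), self.batch_size):
            batch = self.dataset[i:i + self.batch_size]
            yield self.collate_fn(batch) if self.collate_fn is not None else batch
    def __len__(self):
        return -(-len(self.dataset) // self.batch_size)  # ceiling division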
14 changes: 9 additions & 5 deletions jax_dataloader/loaders/jax.py
@@ -19,11 +19,15 @@
def EpochIterator(
data,
batch_size: int,
indices: Sequence[int]
indices: Sequence[int],
collate_fn: Optional[Callable] = None
):
for i in range(0, len(indices), batch_size):
idx = indices[i:i+batch_size]
yield data[idx]
batch = data[idx]
if collate_fn is not None:
batch = collate_fn(batch)
yield batch

# %% ../../nbs/loader.jax.ipynb 5
@dispatch
@@ -48,6 +52,7 @@ def __init__(
num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.
drop_last: bool = False, # if true, drop the last incomplete batch
generator: Optional[GeneratorType] = None, # random seed generator
collate_fn: Optional[Callable] = None, # function to collate samples into batches
**kwargs
):
self.dataset = to_jax_dataset(dataset)
@@ -56,6 +61,7 @@ def __init__(
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.collate_fn = collate_fn

# init rng key via generator
if generator is None:
@@ -72,7 +78,7 @@ def __iter__(self):

if self.drop_last:
indices = indices[:len(self.indices) - len(self.indices) % self.batch_size]
return EpochIterator(self.dataset, self.batch_size, indices)
return EpochIterator(self.dataset, self.batch_size, indices, self.collate_fn)

def next_key(self):
self.key, subkey = jrand.split(self.key)
@@ -81,5 +87,3 @@ def next_key(self):
def __len__(self):
complete_batches, remainder = divmod(len(self.indices), self.batch_size)
return complete_batches if self.drop_last else complete_batches + bool(remainder)

# %%
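Since the jax backend slices the dataset before calling collate_fn, the function sees one already-batched element per step rather than a list of samples. A self-contained sketch of that contract, with plain numpy standing in for the dataset (the name `epoch_iterator_sketch` is hypothetical):

import numpy as np

def epoch_iterator_sketch(data, batch_size, indices, collate_fn=None):
    # mirrors EpochIterator above: slice first, then collate
    for i in range(0, len(indices), batch_size):
        batch = data[indices[i:i + batch_size]]
        yield collate_fn(batch) if collate_fn is not None else batch

data = np.arange(12).reshape(6, 2)
batches = list(epoch_iterator_sketch(data, 2, np.arange(6), lambda b: b * 10))
assert batches[0].tolist() == [[0, 10], [20, 30]]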
20 changes: 20 additions & 0 deletions jax_dataloader/loaders/tensorflow.py
@@ -51,6 +51,7 @@ def __init__(
shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
drop_last: bool = False, # Drop last batch or not
generator: Optional[GeneratorType] = None, # Random seed generator
collate_fn: Optional[Callable] = None, # Function to collate samples into batches
**kwargs
):
super().__init__(dataset, batch_size, shuffle, drop_last)
@@ -62,6 +63,25 @@ def __init__(
ds = to_tf_dataset(dataset)
ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds
ds = ds.batch(batch_size, drop_remainder=drop_last)

# Apply collate_fn if provided
if collate_fn is not None:
# TensorFlow map unpacks arguments, so we need a wrapper that packs them back
def tf_collate_wrapper(*args):
# Pack arguments back into a tuple to match expected collate_fn interface
if len(args) == 1:
batch = args[0]
else:
batch = args # tuple of (X, y, ...)
result = collate_fn(batch)
# tf.data treats a returned tuple/list as separate element components; wrap a single result so the element structure stays a tuple
if isinstance(result, (tuple, list)):
return result
else:
return (result,)

ds = ds.map(tf_collate_wrapper, num_parallel_calls=tf.data.AUTOTUNE)

ds = ds.prefetch(tf.data.AUTOTUNE)
self.dataloader = ds

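One caveat with this backend: ds.map traces tf_collate_wrapper into a TensorFlow graph, so a collate_fn used here should be built from ops TensorFlow can trace (tf.* or auto-convertible operations), not arbitrary Python on numpy arrays. A minimal sketch of a compatible function, assuming (X, y) elements; `tf_friendly_collate` is a hypothetical name:

import tensorflow as tf

def tf_friendly_collate(batch):
    X, y = batch  # the wrapper packs the element components back into a tuple
    return tf.cast(X, tf.float32) + 1.0, y  # traceable ops only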
7 changes: 6 additions & 1 deletion jax_dataloader/loaders/torch.py
@@ -50,6 +50,7 @@ def __init__(
shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
drop_last: bool = False, # Drop last batch or not
generator: Optional[GeneratorType] = None,
collate_fn: Optional[Callable] = None, # Function to collate samples into batches
**kwargs
):
super().__init__(dataset, batch_size, shuffle, drop_last)
@@ -78,13 +79,17 @@ def __init__(
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

# Use custom collate_fn if provided, otherwise use default numpy collate
if collate_fn is None:
collate_fn = _numpy_collate

self.dataloader = torch_data.DataLoader(
dataset,
batch_sampler=batch_sampler,
# batch_size=batch_size,
# shuffle=shuffle,
# drop_last=drop_last,
collate_fn=_numpy_collate,
collate_fn=collate_fn,
**kwargs
)

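Note the backend difference: torch_data.DataLoader hands collate_fn a list of individual samples, unlike the jax backend's pre-batched tuple, so a custom function here must do the stacking itself. A sketch assuming (x, y) tuple samples; `list_collate` is a hypothetical name:

import numpy as np

def list_collate(samples):
    # samples: [(x1, y1), (x2, y2), ...] as yielded by the dataset's __getitem__
    xs, ys = zip(*samples)
    return np.stack(xs) + 1.0, np.asarray(ys)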
70 changes: 68 additions & 2 deletions jax_dataloader/tests.py
@@ -5,9 +5,10 @@
from .imports import *
from .datasets import ArrayDataset
import jax_dataloader as jdl
from jax.tree_util import tree_map

# %% auto 0
__all__ = ['test_shuffle_reproducible', 'test_dataloader']
__all__ = ['test_collate_fn', 'test_shuffle_reproducible', 'test_dataloader']

# %% ../nbs/tests.ipynb 3
def get_batch(batch):
@@ -83,6 +84,70 @@ def test_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
assert len(_X) == len(X_list) * batch_size

# %% ../nbs/tests.ipynb 8
def test_collate_fn(cls, ds, batch_size: int):
"""Test that collate_fn parameter works correctly"""

def custom_collate(batch):
if isinstance(batch, dict):
# HuggingFace format (already batched)
return {'feats': batch['feats'] + 1.0, 'labels': batch['labels']}
elif isinstance(batch, list):
# PyTorch format: list of individual samples
if len(batch) > 0:
if isinstance(batch[0], dict):
# List of dictionaries (HuggingFace with PyTorch backend)
# Convert to batched dict format
keys = batch[0].keys()
result = {}
for key in keys:
values = [item[key] for item in batch]
if key == 'feats':
result[key] = np.stack(values) + 1.0
else:
result[key] = np.array(values)
return result
elif isinstance(batch[0], tuple):
# List of tuples: [(x1, y1), (x2, y2), ...]
X_list, y_list = zip(*batch)
X = np.stack(X_list)
y = np.array(y_list)
return X + 1.0, y
else:
# List of individual arrays
return np.array(batch) + 1.0
elif isinstance(batch, tuple):
# JAX/TF format: already batched tuple (X, y)
X, y = batch
if isinstance(X, torch.Tensor):
X, y = tree_map(np.asarray, (X, y))
return X + 1.0, y
else:
# Single array - already batched
return batch + 1.0

# Test without collate_fn (baseline)
dl_normal = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False)
first_batch_normal = next(iter(dl_normal))

# Test with collate_fn
dl_collate = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=custom_collate)
first_batch_collate = next(iter(dl_collate))

# Verify collate_fn was applied
if isinstance(first_batch_normal, dict):
# HuggingFace format
normal_feats = first_batch_normal['feats']
collate_feats = first_batch_collate['feats']
assert np.array_equal(collate_feats, normal_feats + 1.0), "collate_fn should add 1.0 to features"
assert np.array_equal(first_batch_collate['labels'], first_batch_normal['labels']), "labels should be unchanged"
else:
# Tuple format
normal_X, normal_y = first_batch_normal
collate_X, collate_y = first_batch_collate
assert np.array_equal(collate_X, normal_X + 1.0), "collate_fn should add 1.0 to features"
assert np.array_equal(collate_y, normal_y), "labels should be unchanged"

# %% ../nbs/tests.ipynb 9
def test_shuffle_reproducible(cls, ds, batch_size: int, feats, labels):
"""Test that the shuffle is reproducible"""
def _iter_dataloader(dataloader):
@@ -107,7 +172,7 @@ def _iter_dataloader(dataloader):
X_list_3, Y_list_3 = _iter_dataloader(dl_3)
assert not jnp.array_equal(jnp.concatenate(X_list_1), jnp.concatenate(X_list_3))

# %% ../nbs/tests.ipynb 9
# %% ../nbs/tests.ipynb 10
def test_dataloader(cls, ds_type='jax', samples=1000, batch_size=12):
feats = np.arange(samples).repeat(10).reshape(samples, 10)
labels = np.arange(samples).reshape(samples, 1)
Expand All @@ -129,3 +194,4 @@ def test_dataloader(cls, ds_type='jax', samples=1000, batch_size=12):
test_shuffle(cls, ds, batch_size, feats, labels)
test_shuffle_drop_last(cls, ds, batch_size, feats, labels)
test_shuffle_reproducible(cls, ds, batch_size, feats, labels)
test_collate_fn(cls, ds, batch_size)
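The new check can also be exercised directly against the top-level DataLoader by binding a backend with functools.partial; 'jax' is illustrative (any installed backend should work), and this assumes ArrayDataset is exported at the package top level:

from functools import partial
import numpy as np
import jax_dataloader as jdl
from jax_dataloader.tests import test_collate_fn

ds = jdl.ArrayDataset(np.ones((8, 2)), np.zeros(8))
test_collate_fn(partial(jdl.DataLoader, backend='jax'), ds, batch_size=4)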
51 changes: 16 additions & 35 deletions nbs/core.ipynb
@@ -159,43 +159,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class DataLoader:\n",
" \"\"\"Main Dataloader class to load Numpy data batches\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" dataset, # Dataset from which to load the data\n",
" backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset\n",
" batch_size: int = 1, # How many samples per batch to load\n",
" shuffle: bool = False, # If true, dataloader reshuffles every epoch\n",
" drop_last: bool = False, # If true, drop the last incomplete batch\n",
" generator: Optional[GeneratorType] = None, # Random seed generator\n",
" **kwargs\n",
" ):\n",
" dl_cls = _dispatch_dataloader(backend)\n",
" self.dataloader = dl_cls(\n",
" dataset=dataset, \n",
" batch_size=batch_size, \n",
" shuffle=shuffle, \n",
" drop_last=drop_last,\n",
" generator=generator,\n",
" **kwargs\n",
" )\n",
"\n",
" def __len__(self):\n",
" return len(self.dataloader)\n",
"\n",
" def __next__(self):\n",
" return next(self.dataloader)\n",
"\n",
" def __iter__(self):\n",
" return iter(self.dataloader)"
]
"source": "#| export\nclass DataLoader:\n \"\"\"Main Dataloader class to load Numpy data batches\"\"\"\n\n def __init__(\n self,\n dataset, # Dataset from which to load the data\n backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset\n batch_size: int = 1, # How many samples per batch to load\n shuffle: bool = False, # If true, dataloader reshuffles every epoch\n drop_last: bool = False, # If true, drop the last incomplete batch\n generator: Optional[GeneratorType] = None, # Random seed generator\n collate_fn: Optional[Callable] = None, # Function to collate samples into batches\n **kwargs\n ):\n dl_cls = _dispatch_dataloader(backend)\n self.dataloader = dl_cls(\n dataset=dataset, \n batch_size=batch_size, \n shuffle=shuffle, \n drop_last=drop_last,\n generator=generator,\n collate_fn=collate_fn,\n **kwargs\n )\n\n def __len__(self):\n return len(self.dataloader)\n\n def __next__(self):\n return next(self.dataloader)\n\n def __iter__(self):\n return iter(self.dataloader)"
},
{
"attachments": {},
@@ -348,6 +314,21 @@
"w = train(dataloader, next(keys)).block_until_ready()\n",
"# assert np.allclose(eval(dataloader, w), 0.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `collate_fn` functionality\n",
"\n",
"Test that custom `collate_fn` parameter works correctly by applying a transformation to batches."
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": "# Test collate_fn functionality\ndef custom_collate(batch):\n \"\"\"Custom collate function that adds 1.0 to features while keeping labels unchanged\"\"\"\n X, y = batch\n return X + 1.0, y\n\n# Create test dataset and dataloader with custom collate_fn\ntest_dataset = ArrayDataset(np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 1, 2]))\ndataloader = DataLoader(test_dataset, 'jax', batch_size=2, collate_fn=custom_collate)\n\n# Verify collate_fn transforms features: [1,2] -> [2,3], [3,4] -> [4,5]\nfor X, y in dataloader:\n assert np.array_equal(X[0], [2, 3]) # [1,2] + 1.0 = [2,3]\n assert np.array_equal(X[1], [4, 5]) # [3,4] + 1.0 = [4,5]\n assert np.array_equal(y, [0, 1]) # labels unchanged\n break"
}
],
"metadata": {
28 changes: 1 addition & 27 deletions nbs/loader.base.ipynb
@@ -45,35 +45,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class BaseDataLoader:\n",
" \"\"\"Dataloader Interface\"\"\"\n",
" \n",
" def __init__(\n",
" self, \n",
" dataset, \n",
" batch_size: int = 1, # batch size\n",
" shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n",
" num_workers: int = 0, # how many subprocesses to use for data loading.\n",
" drop_last: bool = False,\n",
" generator: Optional[GeneratorType] = None,\n",
" **kwargs\n",
" ):\n",
" pass\n",
"\n",
" def __len__(self):\n",
" raise NotImplementedError\n",
" \n",
" def __next__(self):\n",
" raise NotImplementedError\n",
" \n",
" def __iter__(self):\n",
" raise NotImplementedError"
]
"source": "#| export\nclass BaseDataLoader:\n \"\"\"Dataloader Interface\"\"\"\n \n def __init__(\n self, \n dataset, \n batch_size: int = 1, # batch size\n shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n num_workers: int = 0, # how many subprocesses to use for data loading.\n drop_last: bool = False,\n generator: Optional[GeneratorType] = None,\n collate_fn: Optional[Callable] = None, # function to collate samples into batches\n **kwargs\n ):\n pass\n\n def __len__(self):\n raise NotImplementedError\n \n def __next__(self):\n raise NotImplementedError\n \n def __iter__(self):\n raise NotImplementedError"
}
],
"metadata": {