diff --git a/jax_dataloader/_modidx.py b/jax_dataloader/_modidx.py
index a8bac08..2241a8c 100644
--- a/jax_dataloader/_modidx.py
+++ b/jax_dataloader/_modidx.py
@@ -109,6 +109,7 @@
     'jax_dataloader.loaders.torch.to_torch_dataset': ( 'loader.torch.html#to_torch_dataset',
                                                        'jax_dataloader/loaders/torch.py')},
  'jax_dataloader.tests': { 'jax_dataloader.tests.get_batch': ('tests.html#get_batch', 'jax_dataloader/tests.py'),
+                           'jax_dataloader.tests.test_collate_fn': ('tests.html#test_collate_fn', 'jax_dataloader/tests.py'),
                            'jax_dataloader.tests.test_dataloader': ('tests.html#test_dataloader', 'jax_dataloader/tests.py'),
                            'jax_dataloader.tests.test_no_shuffle': ('tests.html#test_no_shuffle', 'jax_dataloader/tests.py'),
                            'jax_dataloader.tests.test_no_shuffle_drop_last': ( 'tests.html#test_no_shuffle_drop_last',
diff --git a/jax_dataloader/core.py b/jax_dataloader/core.py
index daeafd1..926493f 100644
--- a/jax_dataloader/core.py
+++ b/jax_dataloader/core.py
@@ -98,6 +98,7 @@ def __init__(
         shuffle: bool = False, # If true, dataloader reshuffles every epoch
         drop_last: bool = False, # If true, drop the last incomplete batch
         generator: Optional[GeneratorType] = None, # Random seed generator
+        collate_fn: Optional[Callable] = None, # Function to collate samples into batches
         **kwargs
     ):
         dl_cls = _dispatch_dataloader(backend)
@@ -107,6 +108,7 @@
             shuffle=shuffle, 
             drop_last=drop_last,
             generator=generator,
+            collate_fn=collate_fn,
             **kwargs
         )
 
diff --git a/jax_dataloader/imports.py b/jax_dataloader/imports.py
index d8bc69f..e923bd5 100644
--- a/jax_dataloader/imports.py
+++ b/jax_dataloader/imports.py
@@ -12,6 +12,7 @@
     Literal,
     Union,
     Annotated,
+    Callable,
 )
 import jax
 from jax import vmap, grad, jit, numpy as jnp, random as jrand
diff --git a/jax_dataloader/loaders/base.py b/jax_dataloader/loaders/base.py
index 1168c8e..4ea9146 100644
--- a/jax_dataloader/loaders/base.py
+++ b/jax_dataloader/loaders/base.py
@@ -21,6 +21,7 @@ def __init__(
         num_workers: int = 0, # how many subprocesses to use for data loading.
         drop_last: bool = False,
         generator: Optional[GeneratorType] = None,
+        collate_fn: Optional[Callable] = None, # function to collate samples into batches
         **kwargs
     ):
         pass
diff --git a/jax_dataloader/loaders/jax.py b/jax_dataloader/loaders/jax.py
index 62af712..6f7cf87 100644
--- a/jax_dataloader/loaders/jax.py
+++ b/jax_dataloader/loaders/jax.py
@@ -19,11 +19,15 @@
 def EpochIterator(
     data,
     batch_size: int,
-    indices: Sequence[int]
+    indices: Sequence[int],
+    collate_fn: Optional[Callable] = None
 ):
     for i in range(0, len(indices), batch_size):
         idx = indices[i:i+batch_size]
-        yield data[idx]
+        batch = data[idx]
+        if collate_fn is not None:
+            batch = collate_fn(batch)
+        yield batch
 
 # %% ../../nbs/loader.jax.ipynb 5
 @dispatch
@@ -48,6 +52,7 @@ def __init__(
         num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.
         drop_last: bool = False, # if true, drop the last incomplete batch
         generator: Optional[GeneratorType] = None, # random seed generator
+        collate_fn: Optional[Callable] = None, # function to collate samples into batches
         **kwargs
     ):
         self.dataset = to_jax_dataset(dataset)
@@ -56,6 +61,7 @@
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.drop_last = drop_last
+        self.collate_fn = collate_fn
 
         # init rng key via generator
         if generator is None:
@@ -72,7 +78,7 @@ def __iter__(self):
 
         if self.drop_last:
             indices = indices[:len(self.indices) - len(self.indices) % self.batch_size]
-        return EpochIterator(self.dataset, self.batch_size, indices)
+        return EpochIterator(self.dataset, self.batch_size, indices, self.collate_fn)
 
     def next_key(self):
         self.key, subkey = jrand.split(self.key)
@@ -81,5 +87,3 @@ def __len__(self):
         complete_batches, remainder = divmod(len(self.indices), self.batch_size)
         return complete_batches if self.drop_last else complete_batches + bool(remainder)
-
-# %%
diff --git a/jax_dataloader/loaders/tensorflow.py b/jax_dataloader/loaders/tensorflow.py
index 73e5bc9..656e221 100644
--- a/jax_dataloader/loaders/tensorflow.py
+++ b/jax_dataloader/loaders/tensorflow.py
@@ -51,6 +51,7 @@
         shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
         drop_last: bool = False, # Drop last batch or not
         generator: Optional[GeneratorType] = None, # Random seed generator
+        collate_fn: Optional[Callable] = None, # Function to collate samples into batches
         **kwargs
     ):
         super().__init__(dataset, batch_size, shuffle, drop_last)
@@ -62,6 +63,25 @@
         ds = to_tf_dataset(dataset)
         ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds
         ds = ds.batch(batch_size, drop_remainder=drop_last)
+
+        # Apply collate_fn if provided
+        if collate_fn is not None:
+            # TensorFlow map unpacks arguments, so we need a wrapper that packs them back
+            def tf_collate_wrapper(*args):
+                # Pack arguments back into a tuple to match expected collate_fn interface
+                if len(args) == 1:
+                    batch = args[0]
+                else:
+                    batch = args # tuple of (X, y, ...)
+                result = collate_fn(batch)
+                # Ensure result is unpacked for TensorFlow
+                if isinstance(result, (tuple, list)):
+                    return result
+                else:
+                    return (result,)
+
+            ds = ds.map(tf_collate_wrapper, num_parallel_calls=tf.data.AUTOTUNE)
+
         ds = ds.prefetch(tf.data.AUTOTUNE)
         self.dataloader = ds
diff --git a/jax_dataloader/loaders/torch.py b/jax_dataloader/loaders/torch.py
index f270c7b..a01e410 100644
--- a/jax_dataloader/loaders/torch.py
+++ b/jax_dataloader/loaders/torch.py
@@ -50,6 +50,7 @@
         shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
         drop_last: bool = False, # Drop last batch or not
         generator: Optional[GeneratorType] = None,
+        collate_fn: Optional[Callable] = None, # Function to collate samples into batches
         **kwargs
     ):
         super().__init__(dataset, batch_size, shuffle, drop_last)
@@ -78,13 +79,17 @@
             sampler = SequentialSampler(dataset)
         batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)
 
+        # Use custom collate_fn if provided, otherwise use default numpy collate
+        if collate_fn is None:
+            collate_fn = _numpy_collate
+
         self.dataloader = torch_data.DataLoader(
             dataset, 
             batch_sampler=batch_sampler,
             # batch_size=batch_size, 
             # shuffle=shuffle, 
             # drop_last=drop_last,
-            collate_fn=_numpy_collate,
+            collate_fn=collate_fn,
             **kwargs
         )
diff --git a/jax_dataloader/tests.py b/jax_dataloader/tests.py
index 66bb0d2..a568eb0 100644
--- a/jax_dataloader/tests.py
+++ b/jax_dataloader/tests.py
@@ -5,9 +5,10 @@
 from .imports import *
 from .datasets import ArrayDataset
 import jax_dataloader as jdl
+from jax.tree_util import tree_map
 
 # %% auto 0
-__all__ = ['test_shuffle_reproducible', 'test_dataloader']
+__all__ = ['test_collate_fn', 'test_shuffle_reproducible', 'test_dataloader']
 
 # %% ../nbs/tests.ipynb 3
 def get_batch(batch):
@@ -83,6 +84,70 @@ def test_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
     assert len(_X) == len(X_list) * batch_size
 
 # %% ../nbs/tests.ipynb 8
+def test_collate_fn(cls, ds, batch_size: int):
+    """Test that collate_fn parameter works correctly"""
+
+    def custom_collate(batch):
+        if isinstance(batch, dict):
+            # HuggingFace format (already batched)
+            return {'feats': batch['feats'] + 1.0, 'labels': batch['labels']}
+        elif isinstance(batch, list):
+            # PyTorch format: list of individual samples
+            if len(batch) > 0:
+                if isinstance(batch[0], dict):
+                    # List of dictionaries (HuggingFace with PyTorch backend)
+                    # Convert to batched dict format
+                    keys = batch[0].keys()
+                    result = {}
+                    for key in keys:
+                        values = [item[key] for item in batch]
+                        if key == 'feats':
+                            result[key] = np.stack(values) + 1.0
+                        else:
+                            result[key] = np.array(values)
+                    return result
+                elif isinstance(batch[0], tuple):
+                    # List of tuples: [(x1, y1), (x2, y2), ...]
+                    X_list, y_list = zip(*batch)
+                    X = np.stack(X_list)
+                    y = np.array(y_list)
+                    return X + 1.0, y
+                else:
+                    # List of individual arrays
+                    return np.array(batch) + 1.0
+        elif isinstance(batch, tuple):
+            # JAX/TF format: already batched tuple (X, y)
+            X, y = batch
+            if isinstance(X, torch.Tensor):
+                X, y = tree_map(np.asarray, (X, y))
+            return X + 1.0, y
+        else:
+            # Single array - already batched
+            return batch + 1.0
+
+    # Test without collate_fn (baseline)
+    dl_normal = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False)
+    first_batch_normal = next(iter(dl_normal))
+
+    # Test with collate_fn
+    dl_collate = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=custom_collate)
+    first_batch_collate = next(iter(dl_collate))
+
+    # Verify collate_fn was applied
+    if isinstance(first_batch_normal, dict):
+        # HuggingFace format
+        normal_feats = first_batch_normal['feats']
+        collate_feats = first_batch_collate['feats']
+        assert np.array_equal(collate_feats, normal_feats + 1.0), "collate_fn should add 1.0 to features"
+        assert np.array_equal(first_batch_collate['labels'], first_batch_normal['labels']), "labels should be unchanged"
+    else:
+        # Tuple format
+        normal_X, normal_y = first_batch_normal
+        collate_X, collate_y = first_batch_collate
+        assert np.array_equal(collate_X, normal_X + 1.0), "collate_fn should add 1.0 to features"
+        assert np.array_equal(collate_y, normal_y), "labels should be unchanged"
+
+# %% ../nbs/tests.ipynb 9
 def test_shuffle_reproducible(cls, ds, batch_size: int, feats, labels):
     """Test that the shuffle is reproducible"""
     def _iter_dataloader(dataloader):
@@ -107,7 +172,7 @@ def _iter_dataloader(dataloader):
     X_list_3, Y_list_3 = _iter_dataloader(dl_3)
     assert not jnp.array_equal(jnp.concatenate(X_list_1), jnp.concatenate(X_list_3))
 
-# %% ../nbs/tests.ipynb 9
+# %% ../nbs/tests.ipynb 10
 def test_dataloader(cls, ds_type='jax', samples=1000, batch_size=12):
     feats = np.arange(samples).repeat(10).reshape(samples, 10)
     labels = np.arange(samples).reshape(samples, 1)
@@ -129,3 +194,4 @@ def _iter_dataloader(dataloader):
     test_shuffle(cls, ds, batch_size, feats, labels)
     test_shuffle_drop_last(cls, ds, batch_size, feats, labels)
     test_shuffle_reproducible(cls, ds, batch_size, feats, labels)
+    test_collate_fn(cls, ds, batch_size)
diff --git a/nbs/core.ipynb b/nbs/core.ipynb
index d844cf3..20b8c57 100644
--- a/nbs/core.ipynb
+++ b/nbs/core.ipynb
@@ -159,43 +159,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "#| export\n",
-    "class DataLoader:\n",
-    "    \"\"\"Main Dataloader class to load Numpy data batches\"\"\"\n",
-    "\n",
-    "    def __init__(\n",
-    "        self,\n",
-    "        dataset, # Dataset from which to load the data\n",
-    "        backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset\n",
-    "        batch_size: int = 1, # How many samples per batch to load\n",
-    "        shuffle: bool = False, # If true, dataloader reshuffles every epoch\n",
-    "        drop_last: bool = False, # If true, drop the last incomplete batch\n",
-    "        generator: Optional[GeneratorType] = None, # Random seed generator\n",
-    "        **kwargs\n",
-    "    ):\n",
-    "        dl_cls = _dispatch_dataloader(backend)\n",
-    "        self.dataloader = dl_cls(\n",
-    "            dataset=dataset, \n",
-    "            batch_size=batch_size, \n",
-    "            shuffle=shuffle, \n",
-    "            drop_last=drop_last,\n",
-    "            generator=generator,\n",
-    "            **kwargs\n",
-    "        )\n",
-    "\n",
-    "    def __len__(self):\n",
-    "        return len(self.dataloader)\n",
-    "\n",
-    "    def __next__(self):\n",
-    "        return next(self.dataloader)\n",
-    "\n",
-    "    def __iter__(self):\n",
-    "        return iter(self.dataloader)"
-   ]
+   "source": "#| export\nclass DataLoader:\n    \"\"\"Main Dataloader class to load Numpy data batches\"\"\"\n\n    def __init__(\n        self,\n        dataset, # Dataset from which to load the data\n        backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset\n        batch_size: int = 1, # How many samples per batch to load\n        shuffle: bool = False, # If true, dataloader reshuffles every epoch\n        drop_last: bool = False, # If true, drop the last incomplete batch\n        generator: Optional[GeneratorType] = None, # Random seed generator\n        collate_fn: Optional[Callable] = None, # Function to collate samples into batches\n        **kwargs\n    ):\n        dl_cls = _dispatch_dataloader(backend)\n        self.dataloader = dl_cls(\n            dataset=dataset, \n            batch_size=batch_size, \n            shuffle=shuffle, \n            drop_last=drop_last,\n            generator=generator,\n            collate_fn=collate_fn,\n            **kwargs\n        )\n\n    def __len__(self):\n        return len(self.dataloader)\n\n    def __next__(self):\n        return next(self.dataloader)\n\n    def __iter__(self):\n        return iter(self.dataloader)"
   },
   {
    "attachments": {},
@@ -348,6 +314,21 @@
     "w = train(dataloader, next(keys)).block_until_ready()\n",
     "# assert np.allclose(eval(dataloader, w), 0.)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### `collate_fn` functionality\n",
+    "\n",
+    "Test that custom `collate_fn` parameter works correctly by applying a transformation to batches."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "outputs": [],
+   "source": "# Test collate_fn functionality\ndef custom_collate(batch):\n    \"\"\"Custom collate function that adds 1.0 to features while keeping labels unchanged\"\"\"\n    X, y = batch\n    return X + 1.0, y\n\n# Create test dataset and dataloader with custom collate_fn\ntest_dataset = ArrayDataset(np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 1, 2]))\ndataloader = DataLoader(test_dataset, 'jax', batch_size=2, collate_fn=custom_collate)\n\n# Verify collate_fn transforms features: [1,2] -> [2,3], [3,4] -> [4,5]\nfor X, y in dataloader:\n    assert np.array_equal(X[0], [2, 3]) # [1,2] + 1.0 = [2,3]\n    assert np.array_equal(X[1], [4, 5]) # [3,4] + 1.0 = [4,5]\n    assert np.array_equal(y, [0, 1]) # labels unchanged\n    break"
   }
  ],
 "metadata": {
diff --git a/nbs/loader.base.ipynb b/nbs/loader.base.ipynb
index 63901cd..9e0808d 100644
--- a/nbs/loader.base.ipynb
+++ b/nbs/loader.base.ipynb
@@ -45,35 +45,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "#| export\n",
-    "class BaseDataLoader:\n",
-    "    \"\"\"Dataloader Interface\"\"\"\n",
-    "    \n",
-    "    def __init__(\n",
-    "        self, \n",
-    "        dataset, \n",
-    "        batch_size: int = 1, # batch size\n",
-    "        shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n",
-    "        num_workers: int = 0, # how many subprocesses to use for data loading.\n",
-    "        drop_last: bool = False,\n",
-    "        generator: Optional[GeneratorType] = None,\n",
-    "        **kwargs\n",
-    "    ):\n",
-    "        pass\n",
-    "\n",
-    "    def __len__(self):\n",
-    "        raise NotImplementedError\n",
-    "    \n",
-    "    def __next__(self):\n",
-    "        raise NotImplementedError\n",
-    "    \n",
-    "    def __iter__(self):\n",
-    "        raise NotImplementedError"
-   ]
+   "source": "#| export\nclass BaseDataLoader:\n    \"\"\"Dataloader Interface\"\"\"\n    \n    def __init__(\n        self, \n        dataset, \n        batch_size: int = 1, # batch size\n        shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n        num_workers: int = 0, # how many subprocesses to use for data loading.\n        drop_last: bool = False,\n        generator: Optional[GeneratorType] = None,\n        collate_fn: Optional[Callable] = None, # function to collate samples into batches\n        **kwargs\n    ):\n        pass\n\n    def __len__(self):\n        raise NotImplementedError\n    \n    def __next__(self):\n        raise NotImplementedError\n    \n    def __iter__(self):\n        raise NotImplementedError"
   }
  ],
 "metadata": {
diff --git a/nbs/loader.jax.ipynb b/nbs/loader.jax.ipynb
index 5b718b4..11d4c90 100644
--- a/nbs/loader.jax.ipynb
+++ b/nbs/loader.jax.ipynb
@@ -51,20 +51,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "#| export\n",
-    "def EpochIterator(\n",
-    "    data,\n",
-    "    batch_size: int,\n",
-    "    indices: Sequence[int]\n",
-    "):\n",
-    "    for i in range(0, len(indices), batch_size):\n",
-    "        idx = indices[i:i+batch_size]\n",
-    "        yield data[idx]"
-   ]
+   "source": "#| export\ndef EpochIterator(\n    data,\n    batch_size: int,\n    indices: Sequence[int],\n    collate_fn: Optional[Callable] = None\n):\n    for i in range(0, len(indices), batch_size):\n        idx = indices[i:i+batch_size]\n        batch = data[idx]\n        if collate_fn is not None:\n            batch = collate_fn(batch)\n        yield batch"
   },
   {
    "cell_type": "code",
@@ -86,58 +75,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "#| export\n",
-    "class DataLoaderJAX(BaseDataLoader):\n",
-    "\n",
-    "    @typecheck\n",
-    "    def __init__(\n",
-    "        self, \n",
-    "        dataset: Union[JAXDataset, HFDataset], \n",
-    "        batch_size: int = 1, # batch size\n",
-    "        shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n",
-    "        num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.\n",
-    "        drop_last: bool = False, # if true, drop the last incomplete batch\n",
-    "        generator: Optional[GeneratorType] = None, # random seed generator\n",
-    "        **kwargs\n",
-    "    ):\n",
-    "        self.dataset = to_jax_dataset(dataset)\n",
-    "        \n",
-    "        self.indices = np.arange(len(dataset))\n",
-    "        self.batch_size = batch_size\n",
-    "        self.shuffle = shuffle\n",
-    "        self.drop_last = drop_last\n",
-    "\n",
-    "        # init rng key via generator\n",
-    "        if generator is None:\n",
-    "            # explicitly set the manual seed of the generator \n",
-    "            generator = Generator().manual_seed(get_config().global_seed)\n",
-    "        if not isinstance(generator, Generator):\n",
-    "            generator = Generator(generator=generator)\n",
-    "        \n",
-    "        self.key = generator.jax_generator()\n",
-    "    \n",
-    "    def __iter__(self):\n",
-    "        # shuffle (permutation) indices every epoch \n",
-    "        indices = jrand.permutation(self.next_key(), self.indices).__array__() if self.shuffle else self.indices\n",
-    "        \n",
-    "        if self.drop_last:\n",
-    "            indices = indices[:len(self.indices) - len(self.indices) % self.batch_size]\n",
-    "        return EpochIterator(self.dataset, self.batch_size, indices)\n",
-    "    \n",
-    "    def next_key(self):\n",
-    "        self.key, subkey = jrand.split(self.key)\n",
-    "        return subkey\n",
-    "    \n",
-    "    def __len__(self):\n",
-    "        complete_batches, remainder = divmod(len(self.indices), self.batch_size)\n",
-    "        return complete_batches if self.drop_last else complete_batches + bool(remainder)\n",
-    "\n",
-    "# %%"
-   ]
+   "source": "#| export\nclass DataLoaderJAX(BaseDataLoader):\n\n    @typecheck\n    def __init__(\n        self, \n        dataset: Union[JAXDataset, HFDataset], \n        batch_size: int = 1, # batch size\n        shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n        num_workers: int = 0, # how many subprocesses to use for data loading. Ignored.\n        drop_last: bool = False, # if true, drop the last incomplete batch\n        generator: Optional[GeneratorType] = None, # random seed generator\n        collate_fn: Optional[Callable] = None, # function to collate samples into batches\n        **kwargs\n    ):\n        self.dataset = to_jax_dataset(dataset)\n        \n        self.indices = np.arange(len(dataset))\n        self.batch_size = batch_size\n        self.shuffle = shuffle\n        self.drop_last = drop_last\n        self.collate_fn = collate_fn\n\n        # init rng key via generator\n        if generator is None:\n            # explicitly set the manual seed of the generator \n            generator = Generator().manual_seed(get_config().global_seed)\n        if not isinstance(generator, Generator):\n            generator = Generator(generator=generator)\n        \n        self.key = generator.jax_generator()\n    \n    def __iter__(self):\n        # shuffle (permutation) indices every epoch \n        indices = jrand.permutation(self.next_key(), self.indices).__array__() if self.shuffle else self.indices\n        \n        if self.drop_last:\n            indices = indices[:len(self.indices) - len(self.indices) % self.batch_size]\n        return EpochIterator(self.dataset, self.batch_size, indices, self.collate_fn)\n    \n    def next_key(self):\n        self.key, subkey = jrand.split(self.key)\n        return subkey\n    \n    def __len__(self):\n        complete_batches, remainder = divmod(len(self.indices), self.batch_size)\n        return complete_batches if self.drop_last else complete_batches + bool(remainder)"
   },
   {
    "cell_type": "code",
diff --git a/nbs/loader.tf.ipynb b/nbs/loader.tf.ipynb
index ec6db30..202e117 100644
--- a/nbs/loader.tf.ipynb
+++ b/nbs/loader.tf.ipynb
@@ -70,57 +70,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "#| export\n",
-    "def get_seed(generator: Optional[Generator | jax.Array | torch.Generator] = None) -> int:\n",
-    "    if generator is None:\n",
-    "        generator = Generator()\n",
-    "    \n",
-    "    if not isinstance(generator, Generator):\n",
-    "        generator = Generator(generator=generator)\n",
-    "    \n",
-    "    seed = generator.seed()\n",
-    "    if seed is None:\n",
-    "        warnings.warn(\"No random seed provided. Using default seed which may not guarantee reproducible results.\")\n",
-    "    return seed\n",
-    "\n",
-    "class DataLoaderTensorflow(BaseDataLoader):\n",
-    "    \"\"\"Tensorflow Dataloader\"\"\"\n",
-    "    \n",
-    "    @typecheck\n",
-    "    def __init__(\n",
-    "        self, \n",
-    "        dataset: Union[JAXDataset, TFDataset, HFDataset],\n",
-    "        batch_size: int = 1, # Batch size\n",
-    "        shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n",
-    "        drop_last: bool = False, # Drop last batch or not\n",
-    "        generator: Optional[GeneratorType] = None, # Random seed generator\n",
-    "        **kwargs\n",
-    "    ):\n",
-    "        super().__init__(dataset, batch_size, shuffle, drop_last)\n",
-    "        check_tf_installed()\n",
-    "        # get random seed from generator\n",
-    "        seed = get_seed(generator)\n",
-    "\n",
-    "        # Convert to tf dataset\n",
-    "        ds = to_tf_dataset(dataset)\n",
-    "        ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds\n",
-    "        ds = ds.batch(batch_size, drop_remainder=drop_last)\n",
-    "        ds = ds.prefetch(tf.data.AUTOTUNE)\n",
-    "        self.dataloader = ds\n",
-    "\n",
-    "    def __len__(self):\n",
-    "        return len(self.dataloader)\n",
-    "\n",
-    "    def __next__(self):\n",
-    "        return next(self.dataloader)\n",
-    "\n",
-    "    def __iter__(self):\n",
-    "        return self.dataloader.as_numpy_iterator()"
-   ]
+   "source": "#| export\ndef get_seed(generator: Optional[Generator | jax.Array | torch.Generator] = None) -> int:\n    if generator is None:\n        generator = Generator()\n    \n    if not isinstance(generator, Generator):\n        generator = Generator(generator=generator)\n    \n    seed = generator.seed()\n    if seed is None:\n        warnings.warn(\"No random seed provided. Using default seed which may not guarantee reproducible results.\")\n    return seed\n\nclass DataLoaderTensorflow(BaseDataLoader):\n    \"\"\"Tensorflow Dataloader\"\"\"\n    \n    @typecheck\n    def __init__(\n        self, \n        dataset: Union[JAXDataset, TFDataset, HFDataset],\n        batch_size: int = 1, # Batch size\n        shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n        drop_last: bool = False, # Drop last batch or not\n        generator: Optional[GeneratorType] = None, # Random seed generator\n        collate_fn: Optional[Callable] = None, # Function to collate samples into batches\n        **kwargs\n    ):\n        super().__init__(dataset, batch_size, shuffle, drop_last)\n        check_tf_installed()\n        # get random seed from generator\n        seed = get_seed(generator)\n\n        # Convert to tf dataset\n        ds = to_tf_dataset(dataset)\n        ds = ds.shuffle(buffer_size=len(dataset), seed=seed) if shuffle else ds\n        ds = ds.batch(batch_size, drop_remainder=drop_last)\n        \n        # Apply collate_fn if provided\n        if collate_fn is not None:\n            # TensorFlow map unpacks arguments, so we need a wrapper that packs them back\n            def tf_collate_wrapper(*args):\n                # Pack arguments back into a tuple to match expected collate_fn interface\n                if len(args) == 1:\n                    batch = args[0]\n                else:\n                    batch = args # tuple of (X, y, ...)\n                result = collate_fn(batch)\n                # Ensure result is unpacked for TensorFlow\n                if isinstance(result, (tuple, list)):\n                    return result\n                else:\n                    return (result,)\n            \n            ds = ds.map(tf_collate_wrapper, num_parallel_calls=tf.data.AUTOTUNE)\n        \n        ds = ds.prefetch(tf.data.AUTOTUNE)\n        self.dataloader = ds\n\n    def __len__(self):\n        return len(self.dataloader)\n\n    def __next__(self):\n        return next(self.dataloader)\n\n    def __iter__(self):\n        return self.dataloader.as_numpy_iterator()"
   },
   {
    "cell_type": "code",
diff --git a/nbs/loader.torch.ipynb b/nbs/loader.torch.ipynb
index 1b9a997..15f392c 100644
--- a/nbs/loader.torch.ipynb
+++ b/nbs/loader.torch.ipynb
@@ -99,69 +99,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "#| export\n",
-    "class DataLoaderPytorch(BaseDataLoader):\n",
-    "    \"\"\"Pytorch Dataloader\"\"\"\n",
-    "    \n",
-    "    @typecheck\n",
-    "    def __init__(\n",
-    "        self, \n",
-    "        dataset: Union[JAXDataset, TorchDataset, HFDataset],\n",
-    "        batch_size: int = 1, # Batch size\n",
-    "        shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n",
-    "        drop_last: bool = False, # Drop last batch or not\n",
-    "        generator: Optional[GeneratorType] = None,\n",
-    "        **kwargs\n",
-    "    ):\n",
-    "        super().__init__(dataset, batch_size, shuffle, drop_last)\n",
-    "        check_pytorch_installed()\n",
-    "        from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler\n",
-    "        import torch\n",
-    "\n",
-    "        if 'sampler' in kwargs:\n",
-    "            warnings.warn(\"`sampler` is currently not supported. We will ignore it and use `shuffle` instead.\")\n",
-    "            del kwargs['sampler']\n",
-    "\n",
-    "        # convert to torch dataset\n",
-    "        dataset = to_torch_dataset(dataset)\n",
-    "        # init generator\n",
-    "        if generator is None:\n",
-    "            # explicitly set the manual seed of the generator\n",
-    "            generator = Generator().manual_seed(get_config().global_seed)\n",
-    "        if not isinstance(generator, Generator):\n",
-    "            generator = Generator(generator=generator)\n",
-    "        \n",
-    "        generator = generator.torch_generator()\n",
-    "        # init batch sampler\n",
-    "        if shuffle: \n",
-    "            sampler = RandomSampler(dataset, generator=generator)\n",
-    "        else: \n",
-    "            sampler = SequentialSampler(dataset)\n",
-    "        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)\n",
-    "\n",
-    "        self.dataloader = torch_data.DataLoader(\n",
-    "            dataset, \n",
-    "            batch_sampler=batch_sampler,\n",
-    "            # batch_size=batch_size, \n",
-    "            # shuffle=shuffle, \n",
-    "            # drop_last=drop_last,\n",
-    "            collate_fn=_numpy_collate,\n",
-    "            **kwargs\n",
-    "        )\n",
-    "\n",
-    "    def __len__(self):\n",
-    "        return len(self.dataloader)\n",
-    "\n",
-    "    def __next__(self):\n",
-    "        return next(self.dataloader)\n",
-    "\n",
-    "    def __iter__(self):\n",
-    "        return self.dataloader.__iter__()"
-   ]
+   "source": "#| export\nclass DataLoaderPytorch(BaseDataLoader):\n    \"\"\"Pytorch Dataloader\"\"\"\n    \n    @typecheck\n    def __init__(\n        self, \n        dataset: Union[JAXDataset, TorchDataset, HFDataset],\n        batch_size: int = 1, # Batch size\n        shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n        drop_last: bool = False, # Drop last batch or not\n        generator: Optional[GeneratorType] = None,\n        collate_fn: Optional[Callable] = None, # Function to collate samples into batches\n        **kwargs\n    ):\n        super().__init__(dataset, batch_size, shuffle, drop_last)\n        check_pytorch_installed()\n        from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler\n        import torch\n\n        if 'sampler' in kwargs:\n            warnings.warn(\"`sampler` is currently not supported. We will ignore it and use `shuffle` instead.\")\n            del kwargs['sampler']\n\n        # convert to torch dataset\n        dataset = to_torch_dataset(dataset)\n        # init generator\n        if generator is None:\n            # explicitly set the manual seed of the generator\n            generator = Generator().manual_seed(get_config().global_seed)\n        if not isinstance(generator, Generator):\n            generator = Generator(generator=generator)\n        \n        generator = generator.torch_generator()\n        # init batch sampler\n        if shuffle: \n            sampler = RandomSampler(dataset, generator=generator)\n        else: \n            sampler = SequentialSampler(dataset)\n        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)\n\n        # Use custom collate_fn if provided, otherwise use default numpy collate\n        if collate_fn is None:\n            collate_fn = _numpy_collate\n\n        self.dataloader = torch_data.DataLoader(\n            dataset, \n            batch_sampler=batch_sampler,\n            # batch_size=batch_size, \n            # shuffle=shuffle, \n            # drop_last=drop_last,\n            collate_fn=collate_fn,\n            **kwargs\n        )\n\n    def __len__(self):\n        return len(self.dataloader)\n\n    def __next__(self):\n        return next(self.dataloader)\n\n    def __iter__(self):\n        return self.dataloader.__iter__()"
   },
   {
    "cell_type": "code",
diff --git a/nbs/tests.ipynb b/nbs/tests.ipynb
index 1395d30..d47c37c 100644
--- a/nbs/tests.ipynb
+++ b/nbs/tests.ipynb
@@ -33,7 +33,8 @@
     "from __future__ import print_function, division, annotations\n",
     "from jax_dataloader.imports import *\n",
     "from jax_dataloader.datasets import ArrayDataset\n",
-    "import jax_dataloader as jdl"
+    "import jax_dataloader as jdl\n",
+    "from jax.tree_util import tree_map"
    ]
   },
   {
@@ -144,6 +145,12 @@
     "    assert len(_X) == len(X_list) * batch_size"
    ]
   },
+  {
+   "cell_type": "code",
+   "metadata": {},
+   "outputs": [],
+   "source": "#| export\ndef test_collate_fn(cls, ds, batch_size: int):\n    \"\"\"Test that collate_fn parameter works correctly\"\"\"\n    \n    def custom_collate(batch):\n        if isinstance(batch, dict):\n            # HuggingFace format (already batched)\n            return {'feats': batch['feats'] + 1.0, 'labels': batch['labels']}\n        elif isinstance(batch, list):\n            # PyTorch format: list of individual samples\n            if len(batch) > 0:\n                if isinstance(batch[0], dict):\n                    # List of dictionaries (HuggingFace with PyTorch backend)\n                    # Convert to batched dict format\n                    keys = batch[0].keys()\n                    result = {}\n                    for key in keys:\n                        values = [item[key] for item in batch]\n                        if key == 'feats':\n                            result[key] = np.stack(values) + 1.0\n                        else:\n                            result[key] = np.array(values)\n                    return result\n                elif isinstance(batch[0], tuple):\n                    # List of tuples: [(x1, y1), (x2, y2), ...]\n                    X_list, y_list = zip(*batch)\n                    X = np.stack(X_list)\n                    y = np.array(y_list)\n                    return X + 1.0, y\n                else:\n                    # List of individual arrays\n                    return np.array(batch) + 1.0\n        elif isinstance(batch, tuple):\n            # JAX/TF format: already batched tuple (X, y)\n            X, y = batch\n            if isinstance(X, torch.Tensor):\n                X, y = tree_map(np.asarray, (X, y))\n            return X + 1.0, y\n        else:\n            # Single array - already batched\n            return batch + 1.0\n\n    # Test without collate_fn (baseline)\n    dl_normal = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False)\n    first_batch_normal = next(iter(dl_normal))\n    \n    # Test with collate_fn\n    dl_collate = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=custom_collate)\n    first_batch_collate = next(iter(dl_collate))\n    \n    # Verify collate_fn was applied\n    if isinstance(first_batch_normal, dict):\n        # HuggingFace format\n        normal_feats = first_batch_normal['feats']\n        collate_feats = first_batch_collate['feats']\n        assert np.array_equal(collate_feats, normal_feats + 1.0), \"collate_fn should add 1.0 to features\"\n        assert np.array_equal(first_batch_collate['labels'], first_batch_normal['labels']), \"labels should be unchanged\"\n    else:\n        # Tuple format\n        normal_X, normal_y = first_batch_normal\n        collate_X, collate_y = first_batch_collate\n        assert np.array_equal(collate_X, normal_X + 1.0), \"collate_fn should add 1.0 to features\"\n        assert np.array_equal(collate_y, normal_y), \"labels should be unchanged\""
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -203,7 +210,8 @@
     "    test_no_shuffle_drop_last(cls, ds, batch_size, feats, labels)\n",
     "    test_shuffle(cls, ds, batch_size, feats, labels)\n",
     "    test_shuffle_drop_last(cls, ds, batch_size, feats, labels)\n",
-    "    test_shuffle_reproducible(cls, ds, batch_size, feats, labels)"
+    "    test_shuffle_reproducible(cls, ds, batch_size, feats, labels)\n",
+    "    test_collate_fn(cls, ds, batch_size)"
    ]
   },
   {