Skip to content

Commit 486f896

Browse files
timothyn617KfacJaxDev
authored andcommitted
Enable skip and take of MNIST dataset.
PiperOrigin-RevId: 748714721
1 parent 3241b5f commit 486f896

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

examples/datasets.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
1616
"""
1717
import types
18-
from typing import Callable, Iterator
18+
from typing import Callable, Iterator, Optional
1919

2020
import jax
2121
import jax.numpy as jnp
@@ -53,6 +53,8 @@ def mnist_dataset(
5353
multi_device: bool = True,
5454
reshuffle_each_iteration: bool = True,
5555
dtype: str = "float32",
56+
take: Optional[int] = None,
57+
skip: Optional[int] = None,
5658
) -> Iterator[Batch]:
5759
"""Standard MNIST dataset pipeline.
5860
@@ -72,6 +74,9 @@ def mnist_dataset(
7274
reshuffle_each_iteration: Whether to reshuffle the dataset in a new order
7375
after each iteration.
7476
dtype: The returned data type of the images.
77+
take: If specified, will take the first `take` examples after skipping
78+
`skip` examples.
79+
skip: If specified, will skip the first `skip` examples.
7580
7681
Returns:
7782
The MNIST dataset as a tensorflow dataset.
@@ -110,6 +115,12 @@ def preprocess_batch(
110115

111116
ds = tfds.load(name="mnist", split=split, as_supervised=True)
112117

118+
if skip is not None:
119+
ds = ds.skip(skip)
120+
121+
if take is not None:
122+
ds = ds.take(take)
123+
113124
ds = ds.shard(jax.process_count(), jax.process_index())
114125

115126
ds = ds.cache()

0 commit comments

Comments
 (0)