File tree Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change 1515
1616"""
1717import types
18- from typing import Callable , Iterator
18+ from typing import Callable , Iterator , Optional
1919
2020import jax
2121import jax .numpy as jnp
@@ -53,6 +53,8 @@ def mnist_dataset(
5353 multi_device : bool = True ,
5454 reshuffle_each_iteration : bool = True ,
5555 dtype : str = "float32" ,
56+ take : Optional [int ] = None ,
57+ skip : Optional [int ] = None ,
5658) -> Iterator [Batch ]:
5759 """Standard MNIST dataset pipeline.
5860
@@ -72,6 +74,9 @@ def mnist_dataset(
7274 reshuffle_each_iteration: Whether to reshuffle the dataset in a new order
7375 after each iteration.
7476 dtype: The returned data type of the images.
77+ take: If specified, will take the first `take` examples after skipping
78+ `skip` examples.
79+ skip: If specified, will skip the first `skip` examples.
7580
7681 Returns:
7782 The MNIST dataset as a tensorflow dataset.
@@ -110,6 +115,12 @@ def preprocess_batch(
110115
111116 ds = tfds .load (name = "mnist" , split = split , as_supervised = True )
112117
118+ if skip is not None :
119+ ds = ds .skip (skip )
120+
121+ if take is not None :
122+ ds = ds .take (take )
123+
113124 ds = ds .shard (jax .process_count (), jax .process_index ())
114125
115126 ds = ds .cache ()
You can’t perform that action at this time.
0 commit comments