Skip to content

Commit bc6c38a

Browse files
jpuigcervercopybara-github
authored andcommitted
Add "id" field to the CIFAR datasets, for the user convenience.
This allows to individually identify the dataset examples examples. This does not affect the sharding of the dataset, so the splits are still compatible with the previous version. PiperOrigin-RevId: 302750645
1 parent 88eec14 commit bc6c38a

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

tensorflow_datasets/image/cifar.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ class Cifar10(tfds.core.GeneratorBasedBuilder):
4646
"""CIFAR-10."""
4747

4848
VERSION = tfds.core.Version("3.0.1")
49+
SUPPORTED_VERSIONS = [
50+
tfds.core.Version(
51+
"3.0.2", experiments={tfds.core.Experiment.METADATA: True}
52+
),
53+
]
4954

5055
def _info(self):
5156
return tfds.core.DatasetInfo(
@@ -54,6 +59,7 @@ def _info(self):
5459
"images in 10 classes, with 6000 images per class. There "
5560
"are 50000 training images and 10000 test images."),
5661
features=tfds.features.FeaturesDict({
62+
"id": tfds.features.Text(),
5763
"image": tfds.features.Image(shape=_CIFAR_IMAGE_SHAPE),
5864
"label": tfds.features.ClassLabel(num_classes=10),
5965
}),
@@ -100,19 +106,26 @@ def gen_filenames(filenames):
100106
return [
101107
tfds.core.SplitGenerator(
102108
name=tfds.Split.TRAIN,
103-
gen_kwargs={"filepaths": gen_filenames(cifar_info.train_files)}),
109+
gen_kwargs={
110+
"split_prefix": "train_",
111+
"filepaths": gen_filenames(cifar_info.train_files)
112+
}),
104113
tfds.core.SplitGenerator(
105114
name=tfds.Split.TEST,
106-
gen_kwargs={"filepaths": gen_filenames(cifar_info.test_files)}),
115+
gen_kwargs={
116+
"split_prefix": "test_",
117+
"filepaths": gen_filenames(cifar_info.test_files)
118+
}),
107119
]
108120

109-
def _generate_examples(self, filepaths):
121+
def _generate_examples(self, split_prefix, filepaths):
110122
"""Generate CIFAR examples as dicts.
111123
112124
Shared across CIFAR-{10, 100}. Uses self._cifar_info as
113125
configuration.
114126
115127
Args:
128+
split_prefix (str): Prefix that identifies the split (e.g. "tr" or "te").
116129
filepaths (list[str]): The files to use to generate the data.
117130
118131
Yields:
@@ -123,6 +136,10 @@ def _generate_examples(self, filepaths):
123136
for path in filepaths:
124137
for labels, np_image in _load_data(path, len(label_keys)):
125138
record = dict(zip(label_keys, labels))
139+
# Note: "id" is only provided for the user convenience. To shuffle the
140+
# dataset we use `index`, so that the sharding is compatible with
141+
# earlier versions.
142+
record["id"] = "{}{:05d}".format(split_prefix, index)
126143
record["image"] = np_image
127144
yield index, record
128145
index += 1
@@ -132,6 +149,11 @@ class Cifar100(Cifar10):
132149
"""CIFAR-100 dataset."""
133150

134151
VERSION = tfds.core.Version("3.0.1")
152+
SUPPORTED_VERSIONS = [
153+
tfds.core.Version(
154+
"3.0.2", experiments={tfds.core.Experiment.METADATA: True}
155+
),
156+
]
135157

136158
@property
137159
def _cifar_info(self):
@@ -156,6 +178,7 @@ def _info(self):
156178
"(the class to which it belongs) and a \"coarse\" label "
157179
"(the superclass to which it belongs)."),
158180
features=tfds.features.FeaturesDict({
181+
"id": tfds.features.Text(),
159182
"image": tfds.features.Image(shape=_CIFAR_IMAGE_SHAPE),
160183
"label": tfds.features.ClassLabel(num_classes=100),
161184
"coarse_label": tfds.features.ClassLabel(num_classes=20),

0 commit comments

Comments
 (0)