Skip to content

Commit cbfc20e

Browse files
Conchylicultorcopybara-github
authored andcommitted
Refactor visualisation util
PiperOrigin-RevId: 303010840
1 parent cd799fb commit cbfc20e

File tree

7 files changed

+327
-139
lines changed

7 files changed

+327
-139
lines changed

tensorflow_datasets/core/visualization.py

Lines changed: 0 additions & 130 deletions
This file was deleted.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# coding=utf-8
2+
# Copyright 2020 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Visualizer utils."""
17+
18+
from tensorflow_datasets.core.visualization.image_visualizer import ImageGridVisualizer
19+
from tensorflow_datasets.core.visualization.show_examples import show_examples
20+
from tensorflow_datasets.core.visualization.visualizer import Visualizer
21+
22+
23+
__all__ = [
24+
"ImageGridVisualizer",
25+
"show_examples",
26+
"Visualizer",
27+
]
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# coding=utf-8
2+
# Copyright 2020 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Lint as: python3
17+
"""Image visualizer."""
18+
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
from absl import logging
24+
25+
from tensorflow_datasets.core import dataset_utils
26+
from tensorflow_datasets.core import features as features_lib
27+
from tensorflow_datasets.core import lazy_imports_lib
28+
from tensorflow_datasets.core.visualization import visualizer
29+
30+
31+
def _make_grid(plot_single_ex_fn, ds, rows, cols, plot_scale):
32+
"""Plot each individual example in a grid.
33+
34+
Args:
35+
plot_single_ex_fn: Function with fill a single cell of the grid, with
36+
signature `fn(ax: matplotlib.axes.Axes, ex: Nested[np.array]) -> None`
37+
ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize. Examples
38+
should not be batched. Examples will be consumed in order until
39+
(rows * cols) are read or the dataset is consumed.
40+
rows: `int`, number of rows of the display grid.
41+
cols: `int`, number of columns of the display grid.
42+
plot_scale: `float`, controls the plot size of the images. Keep this
43+
value around 3 to get a good plot. High and low values may cause
44+
the labels to get overlapped.
45+
46+
Returns:
47+
fig: Figure to display.
48+
"""
49+
plt = lazy_imports_lib.lazy_imports.matplotlib.pyplot
50+
51+
num_examples = rows * cols
52+
examples = list(dataset_utils.as_numpy(ds.take(num_examples)))
53+
54+
fig = plt.figure(figsize=(plot_scale * cols, plot_scale * rows))
55+
fig.subplots_adjust(hspace=1 / plot_scale, wspace=1 / plot_scale)
56+
57+
for i, ex in enumerate(examples):
58+
ax = fig.add_subplot(rows, cols, i+1)
59+
plot_single_ex_fn(ax, ex)
60+
61+
plt.show()
62+
return fig
63+
64+
65+
def _add_image(ax, image):
66+
"""Add the image to the given `matplotlib.axes.Axes`."""
67+
plt = lazy_imports_lib.lazy_imports.matplotlib.pyplot
68+
69+
if len(image.shape) != 3:
70+
raise ValueError(
71+
'Image dimension should be 3. tfds.show_examples does not support '
72+
'batched examples or video.')
73+
_, _, c = image.shape
74+
if c == 1:
75+
image = image.reshape(image.shape[:2])
76+
ax.imshow(image, cmap='gray')
77+
ax.grid(False)
78+
plt.xticks([], [])
79+
plt.yticks([], [])
80+
81+
82+
class ImageGridVisualizer(visualizer.Visualizer):
83+
"""Visualizer for supervised image datasets."""
84+
85+
def match(self, ds_info):
86+
"""See base class."""
87+
# Supervised required a single image key
88+
image_keys = visualizer.extract_keys(ds_info.features, features_lib.Image)
89+
return len(image_keys) >= 1
90+
91+
def show(
92+
self,
93+
ds_info,
94+
ds,
95+
rows=3,
96+
cols=3,
97+
plot_scale=3.,
98+
image_key=None,
99+
):
100+
"""Display the dataset.
101+
102+
Args:
103+
ds_info: `tfds.core.DatasetInfo` object of the dataset to visualize.
104+
ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize. Examples
105+
should not be batched. Examples will be consumed in order until
106+
(rows * cols) are read or the dataset is consumed.
107+
rows: `int`, number of rows of the display grid.
108+
cols: `int`, number of columns of the display grid.
109+
plot_scale: `float`, controls the plot size of the images. Keep this
110+
value around 3 to get a good plot. High and low values may cause
111+
the labels to get overlapped.
112+
image_key: `string`, name of the feature that contains the image. If not
113+
set, the system will try to auto-detect it.
114+
"""
115+
# Extract the image key
116+
if not image_key:
117+
image_keys = visualizer.extract_keys(ds_info.features, features_lib.Image)
118+
if len(image_keys) > 1:
119+
raise ValueError(
120+
'Multiple image features detected in the dataset. '
121+
'Use `image_key` argument to override. Images detected: {}'.format(
122+
image_keys))
123+
image_key = image_keys[0]
124+
125+
# Optionally extract the label key
126+
label_keys = visualizer.extract_keys(
127+
ds_info.features, features_lib.ClassLabel)
128+
label_key = label_keys[0] if len(label_keys) == 1 else None
129+
if not label_key:
130+
logging.info('Was not able to auto-infer label.')
131+
132+
# Single image display
133+
def make_cell_fn(ax, ex):
134+
plt = lazy_imports_lib.lazy_imports.matplotlib.pyplot
135+
136+
if not isinstance(ex, dict):
137+
raise ValueError(
138+
'{} requires examples as `dict`, with the same '
139+
'structure as `ds_info.features`. It is currently not compatible '
140+
'with `as_supervised=True`. Received: {}'.format(
141+
type(self).__name__, type(ex)))
142+
143+
_add_image(ax, ex[image_key])
144+
if label_key:
145+
label = ex[label_key]
146+
label_str = ds_info.features[label_key].int2str(label)
147+
plt.xlabel('{} ({})'.format(label_str, label))
148+
149+
# Print the grid
150+
fig = _make_grid(make_cell_fn, ds, rows, cols, plot_scale)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# coding=utf-8
2+
# Copyright 2020 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Lint as: python3
17+
"""Show example util.
18+
"""
19+
20+
from __future__ import absolute_import
21+
from __future__ import division
22+
from __future__ import print_function
23+
24+
from tensorflow_datasets.core.visualization import image_visualizer
25+
26+
_ALL_VISUALIZERS = [
27+
image_visualizer.ImageGridVisualizer(),
28+
]
29+
30+
31+
def show_examples(ds_info, ds, **options_kwargs):
32+
"""Visualize images (and labels) from an image classification dataset.
33+
34+
This function is for interactive use (Colab, Jupyter). It displays and return
35+
a plot of (rows*columns) images from a tf.data.Dataset.
36+
37+
Usage:
38+
```python
39+
ds, ds_info = tfds.load('cifar10', split='train', with_info=True)
40+
fig = tfds.show_examples(ds_info, ds)
41+
```
42+
43+
Args:
44+
ds_info: The dataset info object to which extract the label and features
45+
info. Available either through `tfds.load('mnist', with_info=True)` or
46+
`tfds.builder('mnist').info`
47+
ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize. Examples
48+
should not be batched. Examples will be consumed in order until
49+
(rows * cols) are read or the dataset is consumed.
50+
**options_kwargs: Additional display options, specific to the dataset type
51+
to visualize. Are forwarded to `tfds.visualization.Visualizer.show`.
52+
See the `tfds.visualization` for a list of available visualizers.
53+
54+
Returns:
55+
fig: The `matplotlib.Figure` object
56+
"""
57+
for visualizer in _ALL_VISUALIZERS:
58+
if visualizer.match(ds_info):
59+
visualizer.show(ds_info, ds, **options_kwargs)
60+
break
61+
else:
62+
raise ValueError(
63+
"Visualisation not supported for dataset `{}`".format(ds_info.name)
64+
)

0 commit comments

Comments
 (0)