|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2020 The TensorFlow Datasets Authors. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +# Lint as: python3 |
| 17 | +"""Image visualizer.""" |
| 18 | + |
| 19 | +from __future__ import absolute_import |
| 20 | +from __future__ import division |
| 21 | +from __future__ import print_function |
| 22 | + |
| 23 | +from absl import logging |
| 24 | + |
| 25 | +from tensorflow_datasets.core import dataset_utils |
| 26 | +from tensorflow_datasets.core import features as features_lib |
| 27 | +from tensorflow_datasets.core import lazy_imports_lib |
| 28 | +from tensorflow_datasets.core.visualization import visualizer |
| 29 | + |
| 30 | + |
| 31 | +def _make_grid(plot_single_ex_fn, ds, rows, cols, plot_scale): |
| 32 | + """Plot each individual example in a grid. |
| 33 | +
|
| 34 | + Args: |
| 35 | + plot_single_ex_fn: Function with fill a single cell of the grid, with |
| 36 | + signature `fn(ax: matplotlib.axes.Axes, ex: Nested[np.array]) -> None` |
| 37 | + ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize. Examples |
| 38 | + should not be batched. Examples will be consumed in order until |
| 39 | + (rows * cols) are read or the dataset is consumed. |
| 40 | + rows: `int`, number of rows of the display grid. |
| 41 | + cols: `int`, number of columns of the display grid. |
| 42 | + plot_scale: `float`, controls the plot size of the images. Keep this |
| 43 | + value around 3 to get a good plot. High and low values may cause |
| 44 | + the labels to get overlapped. |
| 45 | +
|
| 46 | + Returns: |
| 47 | + fig: Figure to display. |
| 48 | + """ |
| 49 | + plt = lazy_imports_lib.lazy_imports.matplotlib.pyplot |
| 50 | + |
| 51 | + num_examples = rows * cols |
| 52 | + examples = list(dataset_utils.as_numpy(ds.take(num_examples))) |
| 53 | + |
| 54 | + fig = plt.figure(figsize=(plot_scale * cols, plot_scale * rows)) |
| 55 | + fig.subplots_adjust(hspace=1 / plot_scale, wspace=1 / plot_scale) |
| 56 | + |
| 57 | + for i, ex in enumerate(examples): |
| 58 | + ax = fig.add_subplot(rows, cols, i+1) |
| 59 | + plot_single_ex_fn(ax, ex) |
| 60 | + |
| 61 | + plt.show() |
| 62 | + return fig |
| 63 | + |
| 64 | + |
| 65 | +def _add_image(ax, image): |
| 66 | + """Add the image to the given `matplotlib.axes.Axes`.""" |
| 67 | + plt = lazy_imports_lib.lazy_imports.matplotlib.pyplot |
| 68 | + |
| 69 | + if len(image.shape) != 3: |
| 70 | + raise ValueError( |
| 71 | + 'Image dimension should be 3. tfds.show_examples does not support ' |
| 72 | + 'batched examples or video.') |
| 73 | + _, _, c = image.shape |
| 74 | + if c == 1: |
| 75 | + image = image.reshape(image.shape[:2]) |
| 76 | + ax.imshow(image, cmap='gray') |
| 77 | + ax.grid(False) |
| 78 | + plt.xticks([], []) |
| 79 | + plt.yticks([], []) |
| 80 | + |
| 81 | + |
| 82 | +class ImageGridVisualizer(visualizer.Visualizer): |
| 83 | + """Visualizer for supervised image datasets.""" |
| 84 | + |
| 85 | + def match(self, ds_info): |
| 86 | + """See base class.""" |
| 87 | + # Supervised required a single image key |
| 88 | + image_keys = visualizer.extract_keys(ds_info.features, features_lib.Image) |
| 89 | + return len(image_keys) >= 1 |
| 90 | + |
| 91 | + def show( |
| 92 | + self, |
| 93 | + ds_info, |
| 94 | + ds, |
| 95 | + rows=3, |
| 96 | + cols=3, |
| 97 | + plot_scale=3., |
| 98 | + image_key=None, |
| 99 | + ): |
| 100 | + """Display the dataset. |
| 101 | +
|
| 102 | + Args: |
| 103 | + ds_info: `tfds.core.DatasetInfo` object of the dataset to visualize. |
| 104 | + ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize. Examples |
| 105 | + should not be batched. Examples will be consumed in order until |
| 106 | + (rows * cols) are read or the dataset is consumed. |
| 107 | + rows: `int`, number of rows of the display grid. |
| 108 | + cols: `int`, number of columns of the display grid. |
| 109 | + plot_scale: `float`, controls the plot size of the images. Keep this |
| 110 | + value around 3 to get a good plot. High and low values may cause |
| 111 | + the labels to get overlapped. |
| 112 | + image_key: `string`, name of the feature that contains the image. If not |
| 113 | + set, the system will try to auto-detect it. |
| 114 | + """ |
| 115 | + # Extract the image key |
| 116 | + if not image_key: |
| 117 | + image_keys = visualizer.extract_keys(ds_info.features, features_lib.Image) |
| 118 | + if len(image_keys) > 1: |
| 119 | + raise ValueError( |
| 120 | + 'Multiple image features detected in the dataset. ' |
| 121 | + 'Use `image_key` argument to override. Images detected: {}'.format( |
| 122 | + image_keys)) |
| 123 | + image_key = image_keys[0] |
| 124 | + |
| 125 | + # Optionally extract the label key |
| 126 | + label_keys = visualizer.extract_keys( |
| 127 | + ds_info.features, features_lib.ClassLabel) |
| 128 | + label_key = label_keys[0] if len(label_keys) == 1 else None |
| 129 | + if not label_key: |
| 130 | + logging.info('Was not able to auto-infer label.') |
| 131 | + |
| 132 | + # Single image display |
| 133 | + def make_cell_fn(ax, ex): |
| 134 | + plt = lazy_imports_lib.lazy_imports.matplotlib.pyplot |
| 135 | + |
| 136 | + if not isinstance(ex, dict): |
| 137 | + raise ValueError( |
| 138 | + '{} requires examples as `dict`, with the same ' |
| 139 | + 'structure as `ds_info.features`. It is currently not compatible ' |
| 140 | + 'with `as_supervised=True`. Received: {}'.format( |
| 141 | + type(self).__name__, type(ex))) |
| 142 | + |
| 143 | + _add_image(ax, ex[image_key]) |
| 144 | + if label_key: |
| 145 | + label = ex[label_key] |
| 146 | + label_str = ds_info.features[label_key].int2str(label) |
| 147 | + plt.xlabel('{} ({})'.format(label_str, label)) |
| 148 | + |
| 149 | + # Print the grid |
| 150 | + fig = _make_grid(make_cell_fn, ds, rows, cols, plot_scale) |
0 commit comments