diff --git a/tensorflow_datasets/core/visualization/__init__.py b/tensorflow_datasets/core/visualization/__init__.py
index 4d87d5443da..49270dd14b0 100644
--- a/tensorflow_datasets/core/visualization/__init__.py
+++ b/tensorflow_datasets/core/visualization/__init__.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 
 """Visualizer utils."""
 
+from tensorflow_datasets.core.visualization.audio_visualizer import AudioGridVisualizer
 from tensorflow_datasets.core.visualization.image_visualizer import ImageGridVisualizer
 from tensorflow_datasets.core.visualization.show_examples import show_examples
 from tensorflow_datasets.core.visualization.visualizer import Visualizer
@@ -22,6 +22,7 @@
 __all__ = [
+    "AudioGridVisualizer",
     "ImageGridVisualizer",
     "show_examples",
     "Visualizer",
 ]
diff --git a/tensorflow_datasets/core/visualization/audio_visualizer.py b/tensorflow_datasets/core/visualization/audio_visualizer.py
new file mode 100644
index 00000000000..769fbdabc01
--- /dev/null
+++ b/tensorflow_datasets/core/visualization/audio_visualizer.py
@@ -0,0 +1,95 @@
+"""Audio visualizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow_datasets.core import dataset_utils
+from tensorflow_datasets.core import features as features_lib
+from tensorflow_datasets.core import lazy_imports_lib
+from tensorflow_datasets.core.visualization import visualizer
+
+
+def _make_audio_grid(ds, key, samplerate, rows, cols, plot_scale):
+  """Plots waveforms and playable audio widgets for samples of the dataset.
+
+  Args:
+    ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize.
+    key: `str`, name of the audio feature to plot and play back.
+    samplerate: `int`, sample rate (in Hz) used for audio playback.
+    rows: `int`, number of rows of the display grid.
+    cols: `int`, number of columns of the display grid.
+    plot_scale: `float`, controls the plot size of the images. Keep this
+      value around 3 to get a good plot. High and low values may cause
+      the labels to get overlapped.
+
+  Returns:
+    fig: Waveform figure to display. The IPython audio widgets are shown
+      as a side effect and are not returned.
+  """
+  # IPython is only needed for the playback widgets; import it lazily so
+  # importing this module does not require a notebook environment.
+  import IPython.display as ipd
+  plt = lazy_imports_lib.lazy_imports.matplotlib.pyplot
+
+  num_examples = rows * cols
+  examples = list(dataset_utils.as_numpy(ds.take(num_examples)))
+
+  fig = plt.figure(figsize=(plot_scale * cols, plot_scale * rows))
+  fig.subplots_adjust(hspace=1 / plot_scale, wspace=1 / plot_scale)
+  # Only play back the first 100k samples so long clips stay responsive.
+  t1 = 0
+  t2 = 100 * 1000
+
+  for i, ex in enumerate(examples):
+    ax = fig.add_subplot(rows, cols, i + 1)
+    ax.plot(ex[key])
+    # Play the inferred feature key, not a hard-coded 'audio' feature name,
+    # so datasets with differently-named audio features work too.
+    audio = ex[key]
+    newaudio = audio[t1:t2]
+    ipd.display(ipd.Audio(newaudio, rate=samplerate))
+
+  plt.show()
+  return fig
+
+
+class AudioGridVisualizer(visualizer.Visualizer):
+  """Fixed-grid visualizer for audio datasets."""
+
+  def match(self, ds_info):
+    """See base class."""
+    audio_keys = visualizer.extract_keys(ds_info.features, features_lib.Audio)
+    return len(audio_keys) > 0
+
+  def show(
+      self,
+      ds_info,
+      ds,
+      rows=2,
+      cols=2,
+      plot_scale=3.,
+      audio_key=None,
+  ):
+    """Displays the audio dataset.
+
+    Args:
+      ds_info: `tfds.core.DatasetInfo` object of the dataset to visualize.
+      ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize.
+      rows: `int`, number of rows of the display grid. Default is 2.
+      cols: `int`, number of columns of the display grid. Default is 2.
+      plot_scale: `float`, controls the plot size of the images. Keep this
+        value around 3 to get a good plot. High and low values may cause
+        the labels to get overlapped.
+      audio_key: `string`, name of the feature that contains the audio. If not
+        set, the system will try to auto-detect it.
+    """
+    # Fall back to auto-detection when the caller gave no audio_key;
+    # previously `key` was unassigned on the explicit-key path (NameError).
+    if not audio_key:
+      audio_key = visualizer.extract_keys(ds_info.features, features_lib.Audio)[0]
+    key = audio_key
+    # The feature's sample rate may be unset; default to 16000 Hz.
+    samplerate = ds_info.features[key].sample_rate
+    if not samplerate:
+      samplerate = 16000
+    _make_audio_grid(ds, key, samplerate, rows, cols, plot_scale)
+
diff --git a/tensorflow_datasets/core/visualization/show_examples.py b/tensorflow_datasets/core/visualization/show_examples.py
index 051aeafb4a3..7c4543633ea 100644
--- a/tensorflow_datasets/core/visualization/show_examples.py
+++ b/tensorflow_datasets/core/visualization/show_examples.py
@@ -21,10 +21,12 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow_datasets.core.visualization import audio_visualizer
 from tensorflow_datasets.core.visualization import image_visualizer
 
 _ALL_VISUALIZERS = [
     image_visualizer.ImageGridVisualizer(),
+    audio_visualizer.AudioGridVisualizer(),
 ]
diff --git a/tensorflow_datasets/core/visualization/show_examples_test.py b/tensorflow_datasets/core/visualization/show_examples_test.py
index 2d8488bec12..f2a75844691 100644
--- a/tensorflow_datasets/core/visualization/show_examples_test.py
+++ b/tensorflow_datasets/core/visualization/show_examples_test.py
@@ -28,17 +28,21 @@
 # Import for registration
 from tensorflow_datasets.image_classification import imagenet  # pylint: disable=unused-import,g-bad-import-order
-
+from tensorflow_datasets.audio import crema_d  # pylint: disable=unused-import,g-bad-import-order
 
 class ShowExamplesTest(testing.TestCase):
 
   @mock.patch('matplotlib.pyplot.figure')
   def test_show_examples(self, mock_fig):
     with testing.mock_data(num_examples=20):
       ds, ds_info = registered.load(
           'imagenet2012', split='train', with_info=True)
       visualization.show_examples(ds_info, ds)
 
+      ds, ds_info = registered.load(
+          'crema_d', split='validation', with_info=True)
+      visualization.show_examples(ds_info, ds)
+
 # TODO(tfds): Should add test when there isn't enough examples (ds.take(3))