From 28e36a20745719c5c7404c8e9dd2916e7827b996 Mon Sep 17 00:00:00 2001 From: jjw Date: Fri, 17 Sep 2021 22:09:53 +0100 Subject: [PATCH 1/4] first attempt at extending spectral centroid to include frequency cropping --- torchaudio/functional/functional.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 7982b0b6ce..f1bbcab143 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -1170,12 +1170,16 @@ def spectral_centroid( n_fft: int, hop_length: int, win_length: int, + min_freq: Optional[float] = None, + max_freq: Optional[float] = None, ) -> Tensor: r""" Compute the spectral centroid for each channel along the time axis. The spectral centroid is defined as the weighted average of the - frequency values, weighted by their magnitude. + frequency values, weighted by their magnitude. + Optionally find centroid of a limited range of the spectrum, specified by the + optional min_freq and max_dreq arguments. Args: waveform (Tensor): Tensor of audio of dimension (..., time) @@ -1185,14 +1189,27 @@ def spectral_centroid( n_fft (int): Size of FFT hop_length (int): Length of hop between STFT windows win_length (int): Window size + min_freq (float, optional): Specify a minimum frequency to include in centroid calculation + max_freq (float, optional): Specify a maximum frequency to include in centroid calculation Returns: Tensor: Dimension (..., time) """ + nyquist = sample_rate // 2 + fft_bins = 1 + n_fft // 2 specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length, win_length=win_length, power=1., normalized=False) freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, device=specgram.device).reshape((-1, 1)) + + min_freq_index = int(round((min_freq * fft_bins) / nyquist)) if min_freq is not None else 0 + max_freq_index = int(round((max_freq * fft_bins) / nyquist)) if max_freq is not None else fft_bins + + if min_freq is not None or max_freq is not None: + assert min_freq_index < max_freq_index + specgram = specgram[...,min_freq_index:max_freq_index,:] + freqs = freqs[...,min_freq_index:max_freq_index,:] + freq_dim = -2 return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) From aefe1a7bff7c99ee66ead71e7049b218eea6eb92 Mon Sep 17 00:00:00 2001 From: jacob Date: Mon, 6 Dec 2021 12:52:25 +0000 Subject: [PATCH 2/4] Update torchaudio/functional/functional.py Use variables to make code more readable Co-authored-by: moto <855818+mthrok@users.noreply.github.com> --- torchaudio/functional/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index ce47722029..c7dba94076 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -1169,7 +1169,7 @@ def spectral_centroid( fft_bins = 1 + n_fft // 2 specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length, win_length=win_length, power=1., normalized=False) - freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, + freqs = torch.linspace(0, nyquist, steps=fft_bins, device=specgram.device).reshape((-1, 1)) min_freq_index = int(round((min_freq * fft_bins) / nyquist)) if min_freq is not None else 0 From 514fa66d363f8a1a030c059cffb41a89f3e7811e Mon Sep 17 00:00:00 2001 From: jjw Date: Mon, 6 Dec 2021 13:06:58 +0000 Subject: [PATCH 3/4] Added torchaudio transform --- .../transforms/batch_consistency_test.py | 7 +++++++ torchaudio/transforms.py | 6 +++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py index f1ba7a98eb..073611d821 100644 --- a/test/torchaudio_unittest/transforms/batch_consistency_test.py +++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py @@ -163,6 +163,13 @@ def test_batch_spectral_centroid(self): self.assert_batch_consistency(transform, waveform) + def test_batch_spectral_centroid_with_range(self): + sample_rate = 44100 + waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.SpectralCentroid(sample_rate, min_freq=300., max_freq=500.) + self.assert_batch_consistency(transform, waveform) + def test_batch_pitch_shift(self): sample_rate = 8000 n_steps = -2 diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index bb64a43b14..ec6e4e27bc 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -1395,6 +1395,8 @@ def __init__(self, hop_length: Optional[int] = None, pad: int = 0, window_fn: Callable[..., Tensor] = torch.hann_window, + min_freq: Optional[float] = None, + max_freq: Optional[float] = None, wkwargs: Optional[dict] = None) -> None: super(SpectralCentroid, self).__init__() self.sample_rate = sample_rate @@ -1404,6 +1406,8 @@ def __init__(self, window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) self.register_buffer('window', window) self.pad = pad + self.min_freq = min_freq + self.max_freq = max_freq def forward(self, waveform: Tensor) -> Tensor: r""" @@ -1415,7 +1419,7 @@ def forward(self, waveform: Tensor) -> Tensor: """ return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length, - self.win_length) + self.win_length, self.min_freq, self.max_freq) class PitchShift(torch.nn.Module): From 188a40c9b83941a10186b0e2b7408323d0145cad Mon Sep 17 00:00:00 2001 From: jjw Date: Mon, 6 Dec 2021 13:09:35 +0000 Subject: [PATCH 4/4] Added docstring for new params --- torchaudio/transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index ec6e4e27bc..a46e3ff0de 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -1380,6 +1380,8 @@ class SpectralCentroid(torch.nn.Module): window_fn (Callable[..., Tensor], optional): A function to create a window tensor that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``) + min_freq (float, optional): Specify a minimum frequency to include in centroid calculation + max_freq (float, optional): Specify a maximum frequency to include in centroid calculation Example >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)