Centroid frequency limits #2061

wants to merge 6 commits into
base: main
@@ -163,6 +163,13 @@ def test_batch_spectral_centroid(self):

         self.assert_batch_consistency(transform, waveform)

+    def test_batch_spectral_centroid_with_range(self):
+        sample_rate = 44100
+        waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=6)
+        waveform = waveform.reshape(3, 2, -1)
+        transform = T.SpectralCentroid(sample_rate, min_freq=300., max_freq=500.)
+        self.assert_batch_consistency(transform, waveform)
+
     def test_batch_pitch_shift(self):
         sample_rate = 8000
         n_steps = -2

21 changes: 19 additions & 2 deletions torchaudio/functional/functional.py
@@ -1140,12 +1140,16 @@ def spectral_centroid(
         n_fft: int,
         hop_length: int,
         win_length: int,
+        min_freq: Optional[float] = None,
+        max_freq: Optional[float] = None,
 ) -> Tensor:
     r"""
     Compute the spectral centroid for each channel along the time axis.

     The spectral centroid is defined as the weighted average of the
-    frequency values, weighted by their magnitude.
+    frequency values, weighted by their magnitude.
+    Optionally find the centroid of a limited range of the spectrum, specified by the
+    optional min_freq and max_freq arguments.

     Args:
         waveform (Tensor): Tensor of audio of dimension `(..., time)`
@@ -1155,14 +1159,27 @@ def spectral_centroid(
         n_fft (int): Size of FFT
         hop_length (int): Length of hop between STFT windows
         win_length (int): Window size
+        min_freq (float, optional): Specify a minimum frequency to include in centroid calculation
+        max_freq (float, optional): Specify a maximum frequency to include in centroid calculation

     Returns:
         Tensor: Dimension `(..., time)`
     """
+    nyquist = sample_rate // 2
+    fft_bins = 1 + n_fft // 2
     specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                            win_length=win_length, power=1., normalized=False)
-    freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2,
+    freqs = torch.linspace(0, nyquist, steps=fft_bins,
                            device=specgram.device).reshape((-1, 1))
+
+    min_freq_index = int(round((min_freq * fft_bins) / nyquist)) if min_freq is not None else 0
+    max_freq_index = int(round((max_freq * fft_bins) / nyquist)) if max_freq is not None else fft_bins

Collaborator:

IIRC TorchScript does not handle the ternary operator. Can you run `pytest test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py -k Centroid` and see if it works?

Author:

Hi, this test seems to pass on my system. Should I still swap out the ternary?
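
For reference, the ternary can be unrolled into plain `if` statements, which sidesteps the question either way. The sketch below is a minimal illustration, not code from this PR: `freq_to_bin` is a hypothetical helper name, and the arithmetic simply mirrors the expressions in the diff. It is plain Python and has not been run through `torch.jit.script` here.

```python
from typing import Optional


def freq_to_bin(freq: Optional[float], nyquist: int, fft_bins: int, default: int) -> int:
    """Hypothetical helper: map a frequency in Hz to a bin index.

    Uses the same arithmetic as the diff, written with an explicit branch
    instead of a conditional expression.
    """
    if freq is None:
        # No limit requested: fall back to the caller-supplied default index.
        return default
    return int(round((freq * fft_bins) / nyquist))


# Example with sample_rate=44100 and n_fft=400 (both values are assumptions).
nyquist = 44100 // 2
fft_bins = 1 + 400 // 2
min_freq_index = freq_to_bin(300.0, nyquist, fft_bins, 0)
max_freq_index = freq_to_bin(500.0, nyquist, fft_bins, fft_bins)
```

Note that `freqs` in the diff is built with `torch.linspace` over `fft_bins` points, so adjacent bins are `nyquist / (fft_bins - 1)` apart. Whether the index should scale by `fft_bins` or `fft_bins - 1` is a separate question that this sketch leaves as written.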


+    if min_freq is not None or max_freq is not None:
+        assert min_freq_index < max_freq_index

Collaborator:

Instead of checking the derived values here, the check on the given arguments should happen at the beginning of the function.

Author:

Sorry, I don't fully understand this. Would you mind explaining?
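
One possible reading of the suggestion above: validate `min_freq` and `max_freq` themselves as soon as the function is entered, rather than asserting on the derived bin indices later. The sketch below assumes that reading. `check_freq_range` is a hypothetical helper name, and the exact bounds and error messages are choices made for illustration, not something the reviewer specified.

```python
from typing import Optional


def check_freq_range(sample_rate: int,
                     min_freq: Optional[float],
                     max_freq: Optional[float]) -> None:
    """Hypothetical helper: reject an invalid band before any computation."""
    nyquist = sample_rate // 2
    if min_freq is not None and not (0 <= min_freq < nyquist):
        raise ValueError(f"min_freq must be in [0, {nyquist}), got {min_freq}")
    if max_freq is not None and not (0 < max_freq <= nyquist):
        raise ValueError(f"max_freq must be in (0, {nyquist}], got {max_freq}")
    if min_freq is not None and max_freq is not None and min_freq >= max_freq:
        raise ValueError(f"min_freq ({min_freq}) must be below max_freq ({max_freq})")
```

Called as the first statement of `spectral_centroid`, a check like this would report a bad range in terms of the arguments the caller passed. Two valid but very close frequencies could still round to the same bin, so a check on the indices may remain useful as well.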

+        specgram = specgram[...,min_freq_index:max_freq_index,:]
+        freqs = freqs[...,min_freq_index:max_freq_index,:]
+
     freq_dim = -2
     return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
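
To make the effect of the slicing concrete, here is a small worked example in plain Python. It assumes a single frame with five bins; the bin frequencies and magnitudes are made up for illustration, and `band_limited_centroid` is a hypothetical name, not a torchaudio function.

```python
from typing import List


def band_limited_centroid(magnitudes: List[float], freqs: List[float],
                          lo: int, hi: int) -> float:
    """Magnitude-weighted mean frequency over bins lo (inclusive) to hi (exclusive)."""
    band_mags = magnitudes[lo:hi]
    band_freqs = freqs[lo:hi]
    weighted_sum = sum(f * m for f, m in zip(band_freqs, band_mags))
    return weighted_sum / sum(band_mags)


freqs = [0.0, 100.0, 200.0, 300.0, 400.0]  # bin centre frequencies in Hz
mags = [1.0, 1.0, 2.0, 2.0, 4.0]           # magnitudes for one frame

# Full spectrum: (0 + 100 + 400 + 600 + 1600) / 10
full = band_limited_centroid(mags, freqs, 0, len(freqs))
# Bins 2 and 3 only: (400 + 600) / 4
band = band_limited_centroid(mags, freqs, 2, 4)
```

With these numbers the full-spectrum centroid is 270 Hz, while restricting the range to bins 2 and 3 gives 250 Hz, which necessarily lies between the lowest and highest frequency kept in the band.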

8 changes: 7 additions & 1 deletion torchaudio/transforms.py
@@ -1380,6 +1380,8 @@ class SpectralCentroid(torch.nn.Module):
         window_fn (Callable[..., Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
         wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
+        min_freq (float, optional): Specify a minimum frequency to include in centroid calculation
+        max_freq (float, optional): Specify a maximum frequency to include in centroid calculation

     Example
         >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
@@ -1395,6 +1397,8 @@ def __init__(self,
                  hop_length: Optional[int] = None,
                  pad: int = 0,
                  window_fn: Callable[..., Tensor] = torch.hann_window,
+                 min_freq: Optional[float] = None,
+                 max_freq: Optional[float] = None,
                  wkwargs: Optional[dict] = None) -> None:
         super(SpectralCentroid, self).__init__()
         self.sample_rate = sample_rate
@@ -1404,6 +1408,8 @@ def __init__(self,
         window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
         self.register_buffer('window', window)
         self.pad = pad
+        self.min_freq = min_freq
+        self.max_freq = max_freq

     def forward(self, waveform: Tensor) -> Tensor:
         r"""
@@ -1415,7 +1421,7 @@ def forward(self, waveform: Tensor) -> Tensor:
         """

         return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
-                                   self.win_length)
+                                   self.win_length, self.min_freq, self.max_freq)


 class PitchShift(torch.nn.Module):