Centroid frequency limits #2061
base: main
Changes from all commits
28e36a2
1a51a20
aefe1a7
514fa66
984ff6a
188a40c
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1140,12 +1140,16 @@ def spectral_centroid( | |
n_fft: int, | ||
hop_length: int, | ||
win_length: int, | ||
min_freq: Optional[float] = None, | ||
max_freq: Optional[float] = None, | ||
) -> Tensor: | ||
r""" | ||
Compute the spectral centroid for each channel along the time axis. | ||
|
||
The spectral centroid is defined as the weighted average of the | ||
frequency values, weighted by their magnitude. | ||
frequency values, weighted by their magnitude. | ||
Optionally find centroid of a limited range of the spectrum, specified by the | ||
optional min_freq and max_freq arguments. | ||
|
||
Args: | ||
waveform (Tensor): Tensor of audio of dimension `(..., time)` | ||
|
@@ -1155,14 +1159,27 @@ def spectral_centroid( | |
n_fft (int): Size of FFT | ||
hop_length (int): Length of hop between STFT windows | ||
win_length (int): Window size | ||
min_freq (float, optional): Specify a minimum frequency to include in centroid calculation | ||
max_freq (float, optional): Specify a maximum frequency to include in centroid calculation | ||
|
||
Returns: | ||
Tensor: Dimension `(..., time)` | ||
""" | ||
nyquist = sample_rate // 2 | ||
fft_bins = 1 + n_fft // 2 | ||
specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length, | ||
win_length=win_length, power=1., normalized=False) | ||
freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, | ||
freqs = torch.linspace(0, nyquist, steps=fft_bins, | ||
device=specgram.device).reshape((-1, 1)) | ||
|
||
min_freq_index = int(round((min_freq * fft_bins) / nyquist)) if min_freq is not None else 0 | ||
max_freq_index = int(round((max_freq * fft_bins) / nyquist)) if max_freq is not None else fft_bins | ||
|
||
if min_freq is not None or max_freq is not None: | ||
assert min_freq_index < max_freq_index | ||
Instead of checking the value here, the check on the given arguments should happen at the beginning of the argument.
Sorry, I don't fully understand this. Would you mind explaining?
||
specgram = specgram[...,min_freq_index:max_freq_index,:] | ||
freqs = freqs[...,min_freq_index:max_freq_index,:] | ||
|
||
freq_dim = -2 | ||
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) | ||
|
||
|
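One way to read the inline review comment on the `assert` is that the frequency limits should be validated before any spectrogram is computed. Below is a minimal sketch of that idea, assuming a hypothetical helper named `check_freq_limits` that is not part of this PR. The index formula is copied from the diff as-is; the error messages and the choice of `ValueError` are assumptions made for illustration.

```python
from typing import Optional, Tuple


def check_freq_limits(
    sample_rate: int,
    n_fft: int,
    min_freq: Optional[float] = None,
    max_freq: Optional[float] = None,
) -> Tuple[int, int]:
    """Hypothetical helper, not part of this PR: validate limits, return (start, stop) bins."""
    nyquist = sample_rate // 2
    fft_bins = 1 + n_fft // 2

    # Reject out-of-range arguments up front, before any spectrogram work is done.
    for name, value in (("min_freq", min_freq), ("max_freq", max_freq)):
        if value is not None and not 0 <= value <= nyquist:
            raise ValueError(f"{name} must lie in [0, {nyquist}] Hz, got {value}")

    # Index formula copied from the diff; None means "no limit on this side".
    start = 0 if min_freq is None else int(round(min_freq * fft_bins / nyquist))
    stop = fft_bins if max_freq is None else int(round(max_freq * fft_bins / nyquist))

    # Two nearby frequencies can round to the same bin, so check the bins as well.
    if start >= stop:
        raise ValueError(f"empty bin range [{start}, {stop})")
    return start, stop
```

With `sample_rate=16000` and `n_fft=400`, `check_freq_limits(16000, 400, 1000.0, 2000.0)` returns `(25, 50)`, and swapped limits raise `ValueError` instead of tripping an `assert` midway through the function.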
IIRC, TorchScript does not handle the ternary operator. Can you run
pytest test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py -k Centroid
and see if it works?
Hi, this test seems to pass on my system. Should I still swap out the ternary?
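If the ternary does end up being swapped out, one possible shape is an explicit `if`/`else` in a small helper. The sketch below assumes a hypothetical function named `freq_to_bin`, which does not appear in the PR. It is plain Python and has not been run through `torch.jit.script` here, so whether TorchScript accepts it still needs the consistency test mentioned above.

```python
from typing import Optional


def freq_to_bin(freq: Optional[float], default: int, fft_bins: int, nyquist: int) -> int:
    """Hypothetical helper: same result as the ternary in the diff, written as if/else."""
    if freq is None:
        # No limit given: fall back to the supplied default bin.
        return default
    # Same index formula as in the diff.
    return int(round(freq * fft_bins / nyquist))
```

Under that assumption, the two lines in the diff would become `min_freq_index = freq_to_bin(min_freq, 0, fft_bins, nyquist)` and `max_freq_index = freq_to_bin(max_freq, fft_bins, fft_bins, nyquist)`.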