Skip to content

Commit ccd78ff

Browse files
authored
Warn if the input dtype to TimeStretch is not complex (#3695)
Addresses #3688
1 parent 172260f commit ccd78ff

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/torchaudio/transforms/_transforms.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,13 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
10661066
Stretched spectrogram. The resulting tensor is of the corresponding complex dtype
10671067
as the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``.
10681068
"""
1069+
if not torch.is_complex(complex_specgrams):
1070+
warnings.warn(
1071+
"The input to TimeStretch must be complex type. "
1072+
"Providing non-complex tensor produces invalid results.",
1073+
stacklevel=4,
1074+
)
1075+
10691076
if overriding_rate is None:
10701077
if self.fixed_rate is None:
10711078
raise ValueError("If no fixed_rate is specified, must pass a valid rate to the forward method.")

0 commit comments

Comments
 (0)