@@ -325,7 +325,7 @@ def _stft(
325
325
We can write STFT in terms of convolutions with a DFT kernel.
326
326
At the end:
327
327
* The real part output is: cos_base * input_real + sin_base * input_imag
328
- * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
328
+ * The imaginary part output is: cos_base * input_imag - sin_base * input_real
329
329
Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
330
330
"""
331
331
hop_length = hop_length or mb .floor_div (x = n_fft , y = 4 , before_op = before_op )
@@ -342,7 +342,7 @@ def _stft(
342
342
343
343
# create a window of centered 1s of the requested size
344
344
if win_length :
345
- window = _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
345
+ window = _get_window (win_length = win_length , n_fft = n_fft , window = window , before_op = before_op )
346
346
347
347
# apply time window
348
348
if window :
@@ -358,12 +358,13 @@ def _stft(
358
358
if input_imaginary :
359
359
signal_imaginary = mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
360
360
361
- # conv with DFT kernel across the input signal
362
- # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is:
363
- # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
364
- # If x is complex then x[n]=(a+i*b)
365
- # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
366
- # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
361
+ # Convolve the DFT kernel with the input signal
362
+ # DFT(x[n]) --> X[k] = Σx[n]*e^(-i*2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n])
362
+ # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k))
363
+ # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k))
364
+ # But because our DFT matrix is obtained with the conjugate --> e^(i*2π*n/N*k):
366
+ # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k))
367
+ # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k))
367
368
cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
368
369
sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
369
370
if input_imaginary :
@@ -372,11 +373,11 @@ def _stft(
372
373
373
374
# add everything together
374
375
if input_imaginary :
375
- real_result = mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376
- imag_result = mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
376
+ real_result = mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
377
+ imag_result = mb .add (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
377
378
else :
378
379
real_result = cos_windows_real
379
- imag_result = mb . sub ( x = 0. , y = sin_windows_real , before_op = before_op )
380
+ imag_result = sin_windows_real
380
381
381
382
# reduce the rank of the output
382
383
if should_increase_rank :
@@ -417,17 +418,18 @@ def _istft(
417
418
# By default, use the entire frame
418
419
win_length = win_length or n_fft
419
420
420
- input_shape = mb .shape (x = x , before_op = before_op )
421
- n_frames = input_shape .val [- 1 ]
422
- fft_size = input_shape .val [- 2 ]
423
- # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
421
+ input_shape = mb .shape (x = input_real , before_op = before_op )
422
+ channels = input_shape .val [0 ]
423
+ fft_size = input_shape .val [1 ]
424
+ n_frames = input_shape .val [2 ]
425
+ expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
424
426
425
427
is_onesided = onesided .val if onesided else fft_size != n_fft
426
428
cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
427
429
428
430
# create a window of centered 1s of the requested size
429
431
if win_length :
430
- window = _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
432
+ window = _get_window (win_length = win_length , n_fft = n_fft , window = window , before_op = before_op )
431
433
432
434
# apply time window
433
435
if window :
@@ -447,14 +449,13 @@ def _istft(
447
449
signal_real = mb .mul (x = signal_real , y = multiplier , before_op = before_op )
448
450
signal_imaginary = mb .mul (x = signal_imaginary , y = multiplier , before_op = before_op )
449
451
450
- # Conv with DFT kernel across the input signal
451
- # We can describe the IDFT in terms of DFT just by swapping the input and output
452
+ # Convolve the DFT kernel with the input signal
453
+ # We can describe the IDFT in terms of DFT just by swapping the input and output.
452
454
# ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
453
- # So IDFT(x) = (1/N) * swap(DFT(swap(x)))
454
- # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i)
455
- # If x is complex then x[n]=(a+i*b)
456
- # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
457
- # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
455
+ # IDFT(X[k]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), where K=N
456
+ # So, using the definition in the stft function, we get:
457
+ # real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n))
458
+ # imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n))
458
459
cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
459
460
sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
460
461
cos_windows_imag = mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
@@ -519,6 +520,7 @@ def _overlap_add(
519
520
def _get_window (
520
521
win_length : Var ,
521
522
n_fft : Var ,
523
+ window : Optional [Var ],
522
524
before_op : Operation ,
523
525
) -> Var :
524
526
n_left = (n_fft .val - win_length .val ) // 2
@@ -750,17 +752,21 @@ def _lower_complex_istft(op: Operation):
750
752
is_complex = types .is_complex (op .input .dtype )
751
753
752
754
# check parameters for validity
755
+ if is_complex :
756
+ raise ValueError ("Only complex inputs are allowed" )
753
757
if op .win_length and op .win_length .val > op .n_fft .val :
754
758
raise ValueError ("Window length must be less than or equal to n_fft" )
755
- if is_complex and op .onesided and op .onesided .val :
756
- raise ValueError ("Onesided is only valid for real inputs " )
759
+ if op . return_complex and op .onesided and op .onesided .val :
760
+ raise ValueError ("Complex output is not compatible with onesided " )
757
761
758
762
real , imag = _istft (
759
- op .input .real if is_complex else op .input ,
760
- op .input .imag if is_complex else None ,
761
- op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , before_op = op )
763
+ op .input .real , op .input .imag ,
764
+ op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , op .length , before_op = op )
762
765
763
- return _wrap_complex_output (op .outputs [0 ], real , imag )
766
+ if op .return_complex :
767
+ return _wrap_complex_output (op .outputs [0 ], real , imag )
768
+ else
769
+ return real
764
770
765
771
766
772
@LowerComplex .register_lower_func (op_type = "complex_shape" )
0 commit comments