diff --git a/fft_conv_pytorch/fft_conv.py b/fft_conv_pytorch/fft_conv.py index 81a181e..c8102cd 100644 --- a/fft_conv_pytorch/fft_conv.py +++ b/fft_conv_pytorch/fft_conv.py @@ -132,10 +132,10 @@ def fft_conv( output = irfftn(output_fr, dim=tuple(range(2, signal.ndim))) # Remove extra padded values - crop_slices = [slice(None), slice(None)] + [ + crop_slices = (slice(None), slice(None)) + tuple( slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2]) for i in range(2, signal.ndim) - ] + ) output = output[crop_slices].contiguous() # Optionally, add a bias term before returning.