From 7ab7efb6cec8596491c63fee0a0ef803c9f861ef Mon Sep 17 00:00:00 2001 From: skishore Date: Fri, 13 Jun 2025 10:28:27 +0000 Subject: [PATCH] This change fixes the following error for rocm RuntimeError: torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA. Please rebuild with cuSOLVER Checking the pytorch code aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp, it calls magma. According to magma documentation, least squares method - magma_sgels_gpu does supports only overdetermined systems https://icl.utk.edu/projectsfiles/magma/doxygen/group__magma__gels.html The proposed temporary solution uses https://docs.pytorch.org/docs/stable/generated/torch.linalg.pinv.html as a replacement for linear least squares. This change can be reverted once torch.linalg.lstsq uses hipSolver for underconstrained systems. Tested on rocm with pytest test/torchaudio_unittest/transforms/transforms_cuda_test.py -k test_inverse_melscale --- src/torchaudio/transforms/_transforms.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 802cbd3d77..9c0e58723e 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -17,6 +17,7 @@ _get_sinc_resample_kernel, _stretch_waveform, ) +from torch.utils.cpp_extension import ROCM_HOME __all__ = [] @@ -495,7 +496,11 @@ def forward(self, melspec: Tensor) -> Tensor: if self.n_mels != n_mels: raise ValueError("Expected an input with {} mel bins. Found: {}".format(self.n_mels, n_mels)) - specgram = torch.relu(torch.linalg.lstsq(self.fb.transpose(-1, -2)[None], melspec, driver=self.driver).solution) + if ROCM_HOME is not None: + solution = torch.linalg.pinv(self.fb.transpose(-1, -2)[None]) @ melspec + else: + solution = torch.linalg.lstsq(self.fb.transpose(-1, -2)[None], melspec, driver=self.driver).solution + specgram = torch.relu(solution) # unpack batch specgram = specgram.view(shape[:-2] + (freq, time))