
Commit 37e8115

feat: LPC CUDA kernel (#24)
* feat: CUDA kernels for LPC and complex LPC computation
* refactor: backend selection logic
* refactor: streamline CUDA and CPU runner assignments in recurrence.py
* test: update lpc equivalence test for cuda device
1 parent c662e34 commit 37e8115
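With this commit, `torch.ops.torchlpc.lpc` is also registered for CUDA tensors, so the same op dispatches to the new kernel on GPU and to the existing CPU implementation otherwise. Below is a minimal usage sketch; it assumes that importing `torchlpc` loads the compiled extension, and the shapes follow the wrapper's checks (`x`: [B, T], `A`: [B, T, order], `zi`: [B, order]) while the values themselves are illustrative only.

    import torch
    import torchlpc  # assumed to register the torchlpc custom ops on import

    # Illustrative shapes: batch of 4 signals, 4096 samples, order-2 all-pole filter.
    B, T, order = 4, 4096, 2
    device = "cuda" if torch.cuda.is_available() else "cpu"

    x = torch.randn(B, T, device=device)                # input signal, [B, T]
    A = 0.1 * torch.randn(B, T, order, device=device)   # time-varying coefficients, [B, T, order]
    zi = torch.zeros(B, order, device=device)           # initial conditions, [B, order]

    # On CUDA this now runs lpc_cuda_kernel; on CPU the existing CPU op is used.
    y = torch.ops.torchlpc.lpc(x, A, zi)
    print(y.shape)  # torch.Size([4, 4096])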

File tree

4 files changed: +244 -40 lines

    tests/test_extension.py
    torchlpc/core.py
    torchlpc/csrc/cuda/lpc.cu
    torchlpc/recurrence.py


tests/test_extension.py

Lines changed: 19 additions & 7 deletions
@@ -64,20 +64,32 @@ def test_scan_equiv(samples: int, cmplx: bool, device: str):
     ).item()


-@pytest.mark.parametrize(
-    "samples",
-    [1024],
-)
+@pytest.mark.parametrize("samples", [1021, 4097])
 @pytest.mark.parametrize(
     "cmplx",
     [True, False],
 )
-def test_lpc_equiv(samples: int, cmplx: bool):
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_lpc_equiv(samples: int, cmplx: bool, device: str):
     batch_size = 4
     x, A, zi = tuple(
-        x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
+        x.to(device) for x in create_test_inputs(batch_size, samples, cmplx)
     )
-    numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
+    if device == "cuda":
+        numba_y = lpc_cuda(x, A, zi)
+    else:
+        numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
     ext_y = torch.ops.torchlpc.lpc(x, A, zi)

     assert torch.allclose(numba_y, ext_y)

torchlpc/core.py

Lines changed: 10 additions & 9 deletions
@@ -159,20 +159,21 @@ def lpc_np(x: np.ndarray, A: np.ndarray, zi: np.ndarray) -> np.ndarray:
 class LPC(Function):
     @staticmethod
     def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
-        if x.is_cuda:
-            y = lpc_cuda(x.detach(), A.detach(), zi.detach())
-        elif EXTENSION_LOADED:
+        if EXTENSION_LOADED:
             y = torch.ops.torchlpc.lpc(x, A, zi)
         else:
             warnings.warn(
                 "Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."
             )
-            y = lpc_np(
-                x.detach().cpu().numpy(),
-                A.detach().cpu().numpy(),
-                zi.detach().cpu().numpy(),
-            )
-            y = torch.from_numpy(y).to(x.device, x.dtype)
+            if x.is_cuda:
+                y = lpc_cuda(x.detach(), A.detach(), zi.detach())
+            else:
+                y = lpc_np(
+                    x.detach().cpu().numpy(),
+                    A.detach().cpu().numpy(),
+                    zi.detach().cpu().numpy(),
+                )
+                y = torch.from_numpy(y).to(x.device, x.dtype)
         return y

     @staticmethod

torchlpc/csrc/cuda/lpc.cu

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+#include <assert.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <stdio.h>
+#include <torch/script.h>
+#include <torch/torch.h>
+
+// CUDA kernel for LPC computation
+template <typename scalar_t>
+__global__ void lpc_cuda_kernel(scalar_t* padded_y,  // [B, T + order]
+                                const scalar_t* A,   // [B, T, order]
+                                int64_t B, int64_t T, int64_t order) {
+    extern __shared__ char smem[];
+    scalar_t* sm = reinterpret_cast<scalar_t*>(smem);
+
+    int b = blockIdx.x;
+    int i = threadIdx.x;
+
+    if (b >= B || i >= order) return;
+
+    // Initialize shared memory with the first 'order' elements
+    sm[i] = padded_y[b * (T + order) + i];
+    __syncthreads();
+
+    int circular_idx = 0;
+    for (int t = 0; t < T; ++t) {
+        circular_idx = t % order;
+        scalar_t a = -A[((b * T + t) * order) + i];
+
+        // Compute s as in the Python code
+        int idx_offset = circular_idx - i - 1;
+        if (i > circular_idx - 1) {
+            idx_offset += order;
+        }
+        scalar_t s = sm[(idx_offset + order) % order];
+
+        scalar_t v = a * s;
+
+        if (i == order - 1) {
+            sm[circular_idx] = v;
+            v = padded_y[b * (T + order) + t + order];
+        }
+        __syncthreads();
+
+        // Atomic add to shared memory
+        atomicAdd(&sm[circular_idx], v);
+        __syncthreads();
+
+        if (i == order - 1) {
+            padded_y[b * (T + order) + t + order] = sm[circular_idx];
+        }
+        __syncthreads();
+    }
+}
+// CUDA kernel for complex LPC computation
+template <typename scalar_t>
+__global__ void lpc_cuda_kernel_complex(
+    scalar_t* padded_y_real,  // [B, T + order]
+    scalar_t* padded_y_imag,  // [B, T + order]
+    const scalar_t* A_real,   // [B, T, order]
+    const scalar_t* A_imag,   // [B, T, order]
+    int64_t B, int64_t T, int64_t order) {
+    extern __shared__ char smem[];
+    scalar_t* sm_real = reinterpret_cast<scalar_t*>(smem);
+    scalar_t* sm_imag = sm_real + order;
+
+    int b = blockIdx.x;
+    int i = threadIdx.x;
+
+    if (b >= B || i >= order) return;
+
+    // Initialize shared memory with the first 'order' elements
+    sm_real[i] = padded_y_real[b * (T + order) + i];
+    sm_imag[i] = padded_y_imag[b * (T + order) + i];
+    __syncthreads();
+
+    int circular_idx = 0;
+    for (int t = 0; t < T; ++t) {
+        circular_idx = t % order;
+        scalar_t a_real = -A_real[((b * T + t) * order) + i];
+        scalar_t a_imag = -A_imag[((b * T + t) * order) + i];
+
+        int idx_offset = circular_idx - i - 1;
+        if (i > circular_idx - 1) {
+            idx_offset += order;
+        }
+        int s_idx = (idx_offset + order) % order;
+        scalar_t s_real = sm_real[s_idx];
+        scalar_t s_imag = sm_imag[s_idx];
+
+        // Complex multiply: v = a * s
+        scalar_t v_real = a_real * s_real - a_imag * s_imag;
+        scalar_t v_imag = a_real * s_imag + a_imag * s_real;
+
+        if (i == order - 1) {
+            sm_real[circular_idx] = v_real;
+            sm_imag[circular_idx] = v_imag;
+            v_real = padded_y_real[b * (T + order) + t + order];
+            v_imag = padded_y_imag[b * (T + order) + t + order];
+        }
+        __syncthreads();
+
+        atomicAdd(&sm_real[circular_idx], v_real);
+        atomicAdd(&sm_imag[circular_idx], v_imag);
+        __syncthreads();
+
+        if (i == order - 1) {
+            padded_y_real[b * (T + order) + t + order] = sm_real[circular_idx];
+            padded_y_imag[b * (T + order) + t + order] = sm_imag[circular_idx];
+        }
+        __syncthreads();
+    }
+}
+
+at::Tensor lpc_cuda_wrapper(const at::Tensor& x, const at::Tensor& a,
+                            const at::Tensor& zi) {
+    TORCH_CHECK(x.is_floating_point() || x.is_complex(),
+                "Input must be floating point or complex");
+    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
+                "Coefficients must have the same scalar type as input");
+    TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
+                "Initial conditions must have the same scalar type as input");
+
+    TORCH_CHECK(x.dim() == 2, "Input must be 2D");
+    TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
+    TORCH_CHECK(x.size(0) == zi.size(0),
+                "Batch size of input and initial conditions must match");
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+    auto a_contiguous = a.contiguous();
+
+    at::Tensor out;
+    auto order = a_contiguous.size(2);
+    assert(order <= 1024 && "LPC order must be less than or equal to 1024");
+    auto threads_per_block = order;
+
+    if (x.is_floating_point()) {
+        out = at::cat({zi.flip(1), x}, 1).contiguous();
+        AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "lpc_cuda", [&] {
+            auto padded_y = out.mutable_data_ptr<scalar_t>();
+            auto A = a_contiguous.const_data_ptr<scalar_t>();
+            auto B = x.size(0);
+            auto T = x.size(1);
+
+            lpc_cuda_kernel<scalar_t><<<B, threads_per_block,
+                                        threads_per_block * sizeof(scalar_t)>>>(
+                padded_y, A, B, T, order);
+        });
+    } else {
+        auto out_real =
+            at::cat({at::real(zi).flip(1), at::real(x)}, 1).contiguous();
+        auto out_imag =
+            at::cat({at::imag(zi).flip(1), at::imag(x)}, 1).contiguous();
+        auto a_real = at::real(a_contiguous).contiguous();
+        auto a_imag = at::imag(a_contiguous).contiguous();
+        AT_DISPATCH_FLOATING_TYPES(
+            out_real.scalar_type(), "lpc_cuda_complex", [&] {
+                auto padded_y_real = out_real.mutable_data_ptr<scalar_t>();
+                auto padded_y_imag = out_imag.mutable_data_ptr<scalar_t>();
+                auto A_real = a_real.const_data_ptr<scalar_t>();
+                auto A_imag = a_imag.const_data_ptr<scalar_t>();
+                auto B = x.size(0);
+                auto T = x.size(1);
+
+                lpc_cuda_kernel_complex<scalar_t>
+                    <<<B, threads_per_block,
+                       2 * threads_per_block * sizeof(scalar_t)>>>(
+                        padded_y_real, padded_y_imag, A_real, A_imag, B, T,
+                        order);
+            });
+        out = at::view_as_complex(at::stack({out_real, out_imag}, -1));
+    }
+    return out.slice(1, order, out.size(1)).contiguous();
+}
+
+TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("lpc", &lpc_cuda_wrapper); }
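Each block of `lpc_cuda_kernel` handles one batch element and runs one thread per filter coefficient: thread i forms the partial product -A[b, t, i] * y[t - i - 1] from a shared-memory circular buffer of the last `order` outputs, the partial products are summed into that buffer with `atomicAdd`, and the last thread (i == order - 1) seeds the sum with the input sample and writes the finished output back to `padded_y`. For orientation, here is a deliberately naive, hypothetical Python reference of the same recurrence (the helper name `lpc_reference` and its loop are illustrative, not part of the commit); it assumes real-valued inputs with the [B, T] / [B, T, order] / [B, order] layout used by the wrapper:

    import torch

    def lpc_reference(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
        """Naive sequential LPC: y[t] = x[t] - sum_k A[t, k] * y[t - k - 1]."""
        B, T = x.shape
        order = A.shape[-1]
        # Same memory layout as the kernel: reversed initial conditions, then the signal.
        padded_y = torch.cat([zi.flip(1), x], dim=1)
        for t in range(T):
            # padded_y[:, t : t + order] holds y[t - order], ..., y[t - 1]; flipping
            # lines index k up with coefficient A[:, t, k], which multiplies y[t - k - 1].
            prev = padded_y[:, t : t + order].flip(1)
            padded_y[:, t + order] = x[:, t] - (A[:, t] * prev).sum(dim=1)
        return padded_y[:, order:]

Note that the wrapper caps the order at 1024 (one thread per coefficient in a single block), and the complex variant stores real and imaginary parts in the two halves of the shared-memory buffer.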

torchlpc/recurrence.py

Lines changed: 38 additions & 24 deletions
@@ -8,30 +8,48 @@
 from .core import lpc_cuda, lpc_np
 from . import EXTENSION_LOADED

+if EXTENSION_LOADED:
+    lpc_cuda_runner = torch.ops.torchlpc.lpc
+    lpc_cpu_runner = torch.ops.torchlpc.lpc
+
+    scan_cuda_runner = torch.ops.torchlpc.scan
+    scan_cpu_runner = torch.ops.torchlpc.scan
+else:
+    lpc_cuda_runner = lpc_cuda
+    lpc_cpu_runner = lambda x, A, zi: torch.from_numpy(
+        lpc_np(x.detach().numpy(), A.detach().numpy(), zi.detach().numpy())
+    )
+
+    scan_cuda_runner = lambda impulse, decay, initial_state: (
+        lambda out: (
+            out,
+            compute_linear_recurrence(
+                cuda.as_cuda_array(decay.detach()),
+                cuda.as_cuda_array(impulse.detach()),
+                cuda.as_cuda_array(initial_state.detach()),
+                cuda.as_cuda_array(out),
+                decay.shape[0],
+                decay.shape[1],
+            ),
+        )
+    )(torch.empty_like(impulse))[0]
+    scan_cpu_runner = lambda impulse, decay, initial_state: torch.from_numpy(
+        lpc_np(
+            impulse.detach().numpy(),
+            -decay.unsqueeze(2).detach().numpy(),
+            initial_state.unsqueeze(1).detach().numpy(),
+        )
+    )
+

 def _cuda_recurrence(
     impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
 ) -> torch.Tensor:
     n_dims, n_steps = decay.shape
     if n_dims * WARPSIZE < n_steps:
-        if EXTENSION_LOADED:
-            runner = torch.ops.torchlpc.scan
-        else:
-
-            def runner(impulse, decay, initial_state):
-                out = torch.empty_like(impulse)
-                compute_linear_recurrence(
-                    cuda.as_cuda_array(decay.detach()),
-                    cuda.as_cuda_array(impulse.detach()),
-                    cuda.as_cuda_array(initial_state.detach()),
-                    cuda.as_cuda_array(out),
-                    n_dims,
-                    n_steps,
-                )
-                return out
-
+        runner = scan_cuda_runner
     else:
-        runner = lambda impulse, decay, initial_state: lpc_cuda(
+        runner = lambda impulse, decay, initial_state: lpc_cuda_runner(
             impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)
         )
     return runner(impulse, decay, initial_state)

@@ -44,14 +62,10 @@ def _cpu_recurrence(
     n_dims, _ = decay.shape
     # This is just a rough estimation of the computational cost
     if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
-        runner = torch.ops.torchlpc.scan
+        runner = scan_cpu_runner
    else:
-        runner = lambda impulse, decay, initial_state: torch.from_numpy(
-            lpc_np(
-                impulse.detach().numpy(),
-                -decay.unsqueeze(2).detach().numpy(),
-                initial_state.unsqueeze(1).detach().numpy(),
-            )
+        runner = lambda impulse, decay, initial_state: lpc_cpu_runner(
+            impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)
         )
     return runner(impulse, decay, initial_state)
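The dispatch heuristic in `_cuda_recurrence` is unchanged: the parallel scan runner is chosen when the sequence length dominates the available batch parallelism, otherwise the per-sample LPC runner handles the recurrence. A small sketch of that decision follows; `WARPSIZE` is assumed to be 32 (the CUDA warp size) purely for illustration, while the actual constant is defined elsewhere in the module.

    WARPSIZE = 32  # assumed value for illustration; torchlpc defines the real constant elsewhere

    def uses_scan_runner(n_dims: int, n_steps: int) -> bool:
        # Mirrors the branch in _cuda_recurrence: prefer the parallel scan when the
        # sequence is long relative to the batch-level parallelism.
        return n_dims * WARPSIZE < n_steps

    print(uses_scan_runner(4, 4096))   # True  -> scan_cuda_runner
    print(uses_scan_runner(64, 1024))  # False -> lpc_cuda_runner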
