#include <assert.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include <torch/script.h>
#include <torch/torch.h>

// CUDA kernel for LPC computation
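//
// Computes the time-domain recursion
//     y[t] = x[t] - sum_{j=1..order} A[b, t, j-1] * y[t - j]
// in place on `padded_y`, whose first `order` entries hold the time-reversed
// initial conditions followed by x. One block handles one batch element and
// each of the `order` threads owns one coefficient; the last `order` outputs
// live in a shared-memory circular buffer, so every time step needs only one
// atomic reduction.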
template <typename scalar_t>
__global__ void lpc_cuda_kernel(scalar_t* padded_y,  // [B, T + order]
                                const scalar_t* A,   // [B, T, order]
                                int64_t B, int64_t T, int64_t order) {
    extern __shared__ char smem[];
    scalar_t* sm = reinterpret_cast<scalar_t*>(smem);

    int b = blockIdx.x;
    int i = threadIdx.x;

    if (b >= B || i >= order) return;

    // Initialize shared memory with the first 'order' elements
    sm[i] = padded_y[b * (T + order) + i];
    __syncthreads();

    int circular_idx = 0;
    for (int t = 0; t < T; ++t) {
        circular_idx = t % order;
        scalar_t a = -A[((b * T + t) * order) + i];
        // Gather s = y[t - 1 - i] from the shared-memory circular buffer
        // (mirrors the reference Python implementation).
        int idx_offset = circular_idx - i - 1;
        if (i > circular_idx - 1) {
            idx_offset += order;
        }
        scalar_t s = sm[(idx_offset + order) % order];
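        // Example (derived from the indexing above): with order = 4 and t = 6,
        // circular_idx = 2 and the buffer holds y[4] at sm[0], y[5] at sm[1],
        // y[2] at sm[2] (about to be overwritten by y[6]), and y[3] at sm[3];
        // thread i = 1 reads sm[(2 - 1 - 1 + 4) % 4] = sm[0] = y[4] = y[t - 2].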

        scalar_t v = a * s;
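
        // Thread (order - 1) seeds the accumulator slot with its own product
        // and then contributes x[t] (stored at padded_y[t + order]) instead,
        // so the atomicAdd below leaves y[t] = x[t] + sum_i(-a_i * y[t-1-i])
        // in sm[circular_idx].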
        if (i == order - 1) {
            sm[circular_idx] = v;
            v = padded_y[b * (T + order) + t + order];
        }
        __syncthreads();

        // Atomic add to shared memory
        atomicAdd(&sm[circular_idx], v);
        __syncthreads();

        if (i == order - 1) {
            padded_y[b * (T + order) + t + order] = sm[circular_idx];
        }
        __syncthreads();
    }
}

// CUDA kernel for complex LPC computation
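//
// Same recursion as above, but the real and imaginary planes are carried in
// separate buffers because atomicAdd has no overload for complex types; the
// complex product a * s is expanded manually below.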
template <typename scalar_t>
__global__ void lpc_cuda_kernel_complex(
    scalar_t* padded_y_real,  // [B, T + order]
    scalar_t* padded_y_imag,  // [B, T + order]
    const scalar_t* A_real,   // [B, T, order]
    const scalar_t* A_imag,   // [B, T, order]
    int64_t B, int64_t T, int64_t order) {
    extern __shared__ char smem[];
    scalar_t* sm_real = reinterpret_cast<scalar_t*>(smem);
    scalar_t* sm_imag = sm_real + order;

    int b = blockIdx.x;
    int i = threadIdx.x;

    if (b >= B || i >= order) return;

    // Initialize shared memory with the first 'order' elements
    sm_real[i] = padded_y_real[b * (T + order) + i];
    sm_imag[i] = padded_y_imag[b * (T + order) + i];
    __syncthreads();

    int circular_idx = 0;
    for (int t = 0; t < T; ++t) {
        circular_idx = t % order;
        scalar_t a_real = -A_real[((b * T + t) * order) + i];
        scalar_t a_imag = -A_imag[((b * T + t) * order) + i];

        int idx_offset = circular_idx - i - 1;
        if (i > circular_idx - 1) {
            idx_offset += order;
        }
        int s_idx = (idx_offset + order) % order;
        scalar_t s_real = sm_real[s_idx];
        scalar_t s_imag = sm_imag[s_idx];

        // Complex multiply: v = a * s
        scalar_t v_real = a_real * s_real - a_imag * s_imag;
        scalar_t v_imag = a_real * s_imag + a_imag * s_real;

        if (i == order - 1) {
            sm_real[circular_idx] = v_real;
            sm_imag[circular_idx] = v_imag;
            v_real = padded_y_real[b * (T + order) + t + order];
            v_imag = padded_y_imag[b * (T + order) + t + order];
        }
        __syncthreads();

        atomicAdd(&sm_real[circular_idx], v_real);
        atomicAdd(&sm_imag[circular_idx], v_imag);
        __syncthreads();

        if (i == order - 1) {
            padded_y_real[b * (T + order) + t + order] = sm_real[circular_idx];
            padded_y_imag[b * (T + order) + t + order] = sm_imag[circular_idx];
        }
        __syncthreads();
    }
}
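
// Wrapper for the kernels above: pads the input with the time-reversed
// initial conditions, launches one block per batch element with `order`
// threads, and strips the padding from the result. Complex tensors are split
// into real/imaginary planes since the kernels operate on real scalars only.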
at::Tensor lpc_cuda_wrapper(const at::Tensor& x, const at::Tensor& a,
                            const at::Tensor& zi) {
    TORCH_CHECK(x.is_floating_point() || x.is_complex(),
                "Input must be floating point or complex");
    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
                "Coefficients must have the same scalar type as input");
    TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
                "Initial conditions must have the same scalar type as input");

    TORCH_CHECK(x.dim() == 2, "Input must be 2D");
    TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
    TORCH_CHECK(x.size(0) == zi.size(0),
                "Batch size of input and initial conditions must match");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    auto a_contiguous = a.contiguous();

    at::Tensor out;
    auto order = a_contiguous.size(2);
    // `order` doubles as the number of threads per block, so it must not
    // exceed the CUDA limit of 1024 threads.
    TORCH_CHECK(order <= 1024, "LPC order must be less than or equal to 1024");
    auto threads_per_block = order;

    if (x.is_floating_point()) {
        out = at::cat({zi.flip(1), x}, 1).contiguous();
        AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "lpc_cuda", [&] {
            auto padded_y = out.mutable_data_ptr<scalar_t>();
            auto A = a_contiguous.const_data_ptr<scalar_t>();
            auto B = x.size(0);
            auto T = x.size(1);

            lpc_cuda_kernel<scalar_t>
                <<<B, threads_per_block, threads_per_block * sizeof(scalar_t)>>>(
                    padded_y, A, B, T, order);
        });
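        // Check for kernel launch errors (standard c10 helper; its header is
        // already included above).
        C10_CUDA_KERNEL_LAUNCH_CHECK();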
    } else {
        auto out_real =
            at::cat({at::real(zi).flip(1), at::real(x)}, 1).contiguous();
        auto out_imag =
            at::cat({at::imag(zi).flip(1), at::imag(x)}, 1).contiguous();
        auto a_real = at::real(a_contiguous).contiguous();
        auto a_imag = at::imag(a_contiguous).contiguous();
        AT_DISPATCH_FLOATING_TYPES(
            out_real.scalar_type(), "lpc_cuda_complex", [&] {
                auto padded_y_real = out_real.mutable_data_ptr<scalar_t>();
                auto padded_y_imag = out_imag.mutable_data_ptr<scalar_t>();
                auto A_real = a_real.const_data_ptr<scalar_t>();
                auto A_imag = a_imag.const_data_ptr<scalar_t>();
                auto B = x.size(0);
                auto T = x.size(1);

                lpc_cuda_kernel_complex<scalar_t>
                    <<<B, threads_per_block,
                       2 * threads_per_block * sizeof(scalar_t)>>>(
                        padded_y_real, padded_y_imag, A_real, A_imag, B, T,
                        order);
            });
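        // Check for kernel launch errors on the complex path as well.
        C10_CUDA_KERNEL_LAUNCH_CHECK();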
        out = at::view_as_complex(at::stack({out_real, out_imag}, -1));
    }
    return out.slice(1, order, out.size(1)).contiguous();
}

TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("lpc", &lpc_cuda_wrapper); }
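// Note: this file only provides the CUDA implementation of `torchlpc::lpc`;
// the operator schema itself is assumed to be declared elsewhere in the
// project (e.g. via a TORCH_LIBRARY block).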