@@ -62,6 +62,47 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
                  [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
 }
 
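+// In-place LPC (all-pole) filtering of a batch of signals with per-sample
+// time-varying coefficients. `a` is (B, T, order); `padded_out` is
+// (B, T + order), holding the initial conditions followed by the input.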
+template <typename scalar_t>
+void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
+  // Ensure input dimensions are correct
+  TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
+  TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
+  TORCH_CHECK(padded_out.size(0) == a.size(0),
+              "Batch size of out and x must match");
+  TORCH_CHECK(padded_out.size(1) == (a.size(1) + a.size(2)),
+              "Time dimension of out must match x and a");
+  TORCH_INTERNAL_ASSERT(a.device().is_cpu(), "a must be on CPU");
+  TORCH_INTERNAL_ASSERT(padded_out.device().is_cpu(),
+                        "Output must be on CPU");
+  TORCH_INTERNAL_ASSERT(padded_out.is_contiguous(),
+                        "Output must be contiguous");
+
+  // Get the dimensions
+  const auto B = a.size(0);
+  const auto T = a.size(1);
+  const auto order = a.size(2);
+
+  auto a_contiguous = a.contiguous();
+
+  const scalar_t *a_ptr = a_contiguous.data_ptr<scalar_t>();
+  scalar_t *out_ptr = padded_out.data_ptr<scalar_t>();
+
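+  // All-pole recursion, computed in place on the padded buffer:
+  //   y[b, t] = x[b, t] - sum_{i=0..order-1} a[b, t, i] * y[b, t - 1 - i]
+  // The first `order` samples of each row of padded_out hold the initial
+  // conditions, so the lookback never reads out of bounds. Batch elements
+  // are independent, so the outer loop is parallelized.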
+  at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
+    for (auto b = start; b < end; b++) {
+      auto out_offset = b * (T + order) + order;
+      auto a_offset = b * T * order;
+      for (int64_t t = 0; t < T; t++) {
+        scalar_t y = out_ptr[out_offset + t];
+        for (int64_t i = 0; i < order; i++) {
+          y -= a_ptr[a_offset + t * order + i] *
+               out_ptr[out_offset + t - i - 1];
+        }
+        out_ptr[out_offset + t] = y;
+      }
+    }
+  });
+}
+
 at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
                             const at::Tensor &initials) {
   TORCH_CHECK(input.is_floating_point() || input.is_complex(),
@@ -79,8 +120,33 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
   return output;
 }
 
+at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
+                   const at::Tensor &zi) {
+  TORCH_CHECK(x.is_floating_point() || x.is_complex(),
+              "Input must be floating point or complex");
+  TORCH_CHECK(a.scalar_type() == x.scalar_type(),
+              "Coefficients must have the same scalar type as input");
+  TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
+              "Initial conditions must have the same scalar type as input");
+
+  TORCH_CHECK(x.dim() == 2, "Input must be 2D");
+  TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
+  TORCH_CHECK(x.size(0) == zi.size(0),
+              "Batch size of input and initial conditions must match");
+
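+  // Prepend the time-reversed initial conditions: after the flip, zi[:, 0]
+  // sits immediately before x[:, 0] and is read as y[-1] by the recursion.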
+  auto out = at::cat({zi.flip(1), x}, 1).contiguous();
+
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+      x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
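+  // Drop the zi prefix so the result has the same length as the input.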
+  return out.slice(1, zi.size(1), out.size(1)).contiguous();
+}
+
 TORCH_LIBRARY(torchlpc, m) {
   m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
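+  // New schema for the LPC filter; positionally the arguments map to
+  // lpc_cpu's (x, a, zi).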
+  m.def("torchlpc::lpc_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
 }
 
-TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); }
+TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
+  m.impl("scan_cpu", &scan_cpu_wrapper);
+  m.impl("lpc_cpu", &lpc_cpu);
+}