Commit d372cee

feat: cpp extension for LPC (#18)
* feat: add LPC CPU implementation and wrapper function
* test: add equivalence tests for lpc_cpu function
* feat: openmp compilation flag
* feat: use cpp lpc and add deprecation warning
1 parent b9c9307 commit d372cee

4 files changed, +108 -2 lines

setup.py

Lines changed: 15 additions & 1 deletion
@@ -1,4 +1,5 @@
 import setuptools
+import torch
 from torch.utils import cpp_extension
 
 NAME = "torchlpc"
@@ -10,6 +11,14 @@
 with open("README.md", "r") as fh:
     long_description = fh.read()
 
+
+extra_link_args = []
+extra_compile_args = {}
+# check if openmp is available
+if torch.backends.openmp.is_available():
+    extra_compile_args["cxx"] = ["-fopenmp"]
+    extra_link_args.append("-lgomp")
+
 setuptools.setup(
     name=NAME,
     version=VERSION,
@@ -27,7 +36,12 @@
         "Operating System :: OS Independent",
     ],
     ext_modules=[
-        cpp_extension.CppExtension("torchlpc._C", ["torchlpc/csrc/scan_cpu.cpp"])
+        cpp_extension.CppExtension(
+            "torchlpc._C",
+            ["torchlpc/csrc/scan_cpu.cpp"],
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+        )
     ],
     cmdclass={"build_ext": cpp_extension.BuildExtension},
 )
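
The build now probes `torch.backends.openmp.is_available()` and, when it returns true, compiles the extension with `-fopenmp` and links `-lgomp`. A minimal sketch of the same probe, handy for checking which path the build will take on a given machine (assumes only a working `torch` install):

```python
# Sketch: the same probe setup.py uses to decide on OpenMP flags.
import torch

if torch.backends.openmp.is_available():
    print("OpenMP available: extension builds with -fopenmp / -lgomp")
else:
    print("OpenMP unavailable: extension builds without OpenMP")
```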

tests/test_extension.py

Lines changed: 19 additions & 0 deletions
@@ -33,3 +33,22 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool):
     ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)
 
     assert torch.allclose(numba_y, ext_y)
+
+
+@pytest.mark.parametrize(
+    "samples",
+    [1024],
+)
+@pytest.mark.parametrize(
+    "cmplx",
+    [True, False],
+)
+def test_lpc_cpu_equiv(samples: int, cmplx: bool):
+    batch_size = 4
+    x, A, zi = tuple(
+        x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
+    )
+    numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
+    ext_y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
+
+    assert torch.allclose(numba_y, ext_y)
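
The new test mirrors `test_scan_cpu_equiv`, comparing `torch.ops.torchlpc.lpc_cpu` against the Numba-backed `lpc_np`. For readers who don't want to chase that reference, here is a hedged pure-NumPy sketch of the recurrence both sides compute (`lpc_reference` is a hypothetical name, not part of the package):

```python
# Illustrative pure-NumPy version of the recurrence under test
# (hypothetical helper; lpc_np in torchlpc.core is the real reference).
import numpy as np

def lpc_reference(x, a, zi):
    """x: (B, T) input, a: (B, T, order) coefficients, zi: (B, order) initial state."""
    B, T = x.shape
    order = a.shape[2]
    # Prepend flipped initial conditions, mirroring at::cat({zi.flip(1), x}, 1).
    y = np.concatenate([zi[:, ::-1], x], axis=1)
    for b in range(B):
        for t in range(T):
            for i in range(order):
                y[b, order + t] -= a[b, t, i] * y[b, order + t - 1 - i]
    return y[:, order:]
```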

torchlpc/core.py

Lines changed: 7 additions & 0 deletions
@@ -1,10 +1,12 @@
+import warnings
 import torch
 import numpy as np
 import torch.nn.functional as F
 from torch.autograd import Function
 from typing import Any, Tuple, Optional, Callable, List
 from numba import jit, njit, prange, cuda, float32, float64, complex64, complex128
 
+from . import EXTENSION_LOADED
 
 lpc_cuda_kernel_float32: Callable = None
 lpc_cuda_kernel_float64: Callable = None
@@ -159,7 +161,12 @@ class LPC(Function):
     def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
         if x.is_cuda:
             y = lpc_cuda(x.detach(), A.detach(), zi.detach())
+        elif EXTENSION_LOADED:
+            y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
         else:
+            warnings.warn(
+                "Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."
+            )
             y = lpc_np(
                 x.detach().cpu().numpy(),
                 A.detach().cpu().numpy(),
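
With the extension present, CPU tensors now route through the registered op instead of Numba, and the fallback path warns about the coming deprecation. A minimal usage sketch, assuming the extension compiled successfully (shapes follow the checks in `lpc_cpu`: `x` is (B, T), `a` is (B, T, order), `zi` is (B, order)):

```python
# Minimal sketch: call the registered op directly (assumes the C++
# extension compiled and torchlpc imported it, setting EXTENSION_LOADED).
import torch
import torchlpc  # registers torch.ops.torchlpc.* when the extension loads

B, T, order = 4, 1024, 2
x = torch.randn(B, T)
a = 0.1 * torch.randn(B, T, order)  # small coefficients keep the filter stable
zi = torch.zeros(B, order)

y = torch.ops.torchlpc.lpc_cpu(x, a, zi)
print(y.shape)  # torch.Size([4, 1024])
```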

torchlpc/csrc/scan_cpu.cpp

Lines changed: 67 additions & 1 deletion
@@ -62,6 +62,47 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
         [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
 }
 
+template <typename scalar_t>
+void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
+    // Ensure input dimensions are correct
+    TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
+    TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
+    TORCH_CHECK(padded_out.size(0) == a.size(0),
+                "Batch size of out and x must match");
+    TORCH_CHECK(padded_out.size(1) == (a.size(1) + a.size(2)),
+                "Time dimension of out must match x and a");
+    TORCH_INTERNAL_ASSERT(a.device().is_cpu(), "a must be on CPU");
+    TORCH_INTERNAL_ASSERT(padded_out.device().is_cpu(),
+                          "Output must be on CPU");
+    TORCH_INTERNAL_ASSERT(padded_out.is_contiguous(),
+                          "Output must be contiguous");
+
+    // Get the dimensions
+    const auto B = a.size(0);
+    const auto T = a.size(1);
+    const auto order = a.size(2);
+
+    auto a_contiguous = a.contiguous();
+
+    const scalar_t *a_ptr = a_contiguous.data_ptr<scalar_t>();
+    scalar_t *out_ptr = padded_out.data_ptr<scalar_t>();
+
+    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
+        for (auto b = start; b < end; b++) {
+            auto out_offset = b * (T + order) + order;
+            auto a_offset = b * T * order;
+            for (int64_t t = 0; t < T; t++) {
+                scalar_t y = out_ptr[out_offset + t];
+                for (int64_t i = 0; i < order; i++) {
+                    y -= a_ptr[a_offset + t * order + i] *
+                         out_ptr[out_offset + t - i - 1];
+                }
+                out_ptr[out_offset + t] = y;
+            }
+        }
+    });
+}
+
 at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
                             const at::Tensor &initials) {
     TORCH_CHECK(input.is_floating_point() || input.is_complex(),
@@ -79,8 +120,33 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
     return output;
 }
 
+at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
+                   const at::Tensor &zi) {
+    TORCH_CHECK(x.is_floating_point() || x.is_complex(),
+                "Input must be floating point or complex");
+    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
+                "Coefficients must have the same scalar type as input");
+    TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
+                "Initial conditions must have the same scalar type as input");
+
+    TORCH_CHECK(x.dim() == 2, "Input must be 2D");
+    TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
+    TORCH_CHECK(x.size(0) == zi.size(0),
+                "Batch size of input and initial conditions must match");
+
+    auto out = at::cat({zi.flip(1), x}, 1).contiguous();
+
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+        x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
+    return out.slice(1, zi.size(1), out.size(1)).contiguous();
+}
+
 TORCH_LIBRARY(torchlpc, m) {
     m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
+    m.def("torchlpc::lpc_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
 }
 
-TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); }
+TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
+    m.impl("scan_cpu", &scan_cpu_wrapper);
+    m.impl("lpc_cpu", &lpc_cpu);
+}
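
For reference, `lpc_cpu_core` evaluates the time-varying all-pole recurrence in place on the padded buffer,

    y[b, t] = x[b, t] - sum_{i=0}^{order-1} a[b, t, i] * y[b, t - 1 - i]

where y[b, -order], ..., y[b, -1] are supplied by the flipped `zi`. Prepending the flipped initial conditions lets the inner loop index uniformly with no boundary branches, and only the batch loop is parallelized (via `at::parallel_for`), since each y[b, t] feeds into later time steps.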
