16 changes: 15 additions & 1 deletion setup.py
@@ -1,4 +1,5 @@
import setuptools
import torch
from torch.utils import cpp_extension

NAME = "torchlpc"
@@ -10,6 +11,14 @@
with open("README.md", "r") as fh:
    long_description = fh.read()


extra_link_args = []
extra_compile_args = {}
# check if openmp is available
if torch.backends.openmp.is_available():
    extra_compile_args["cxx"] = ["-fopenmp"]
    extra_link_args.append("-lgomp")

setuptools.setup(
    name=NAME,
    version=VERSION,
@@ -27,7 +36,12 @@
"Operating System :: OS Independent",
],
ext_modules=[
cpp_extension.CppExtension("torchlpc._C", ["torchlpc/csrc/scan_cpu.cpp"])
cpp_extension.CppExtension(
"torchlpc._C",
["torchlpc/csrc/scan_cpu.cpp"],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
],
cmdclass={"build_ext": cpp_extension.BuildExtension},
)
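The OpenMP flags are added only when the local libtorch reports OpenMP support, since at::parallel_for (used by the C++ kernel below) typically parallelizes via OpenMP on CPU. A standalone sketch of the same probe, assuming a GCC-compatible toolchain (hence -fopenmp and -lgomp):

# Standalone sketch of the probe used in setup.py above; assumes a
# GCC-compatible compiler, where OpenMP is enabled with -fopenmp/-lgomp.
import torch

extra_compile_args, extra_link_args = {}, []
if torch.backends.openmp.is_available():
    extra_compile_args["cxx"] = ["-fopenmp"]
    extra_link_args.append("-lgomp")
print("OpenMP available:", torch.backends.openmp.is_available())
print(extra_compile_args, extra_link_args)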
19 changes: 19 additions & 0 deletions tests/test_extension.py
@@ -33,3 +33,22 @@ def test_scan_cpu_equiv(samples: int, cmplx: bool):
    ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)

    assert torch.allclose(numba_y, ext_y)


@pytest.mark.parametrize(
    "samples",
    [1024],
)
@pytest.mark.parametrize(
    "cmplx",
    [True, False],
)
def test_lpc_cpu_equiv(samples: int, cmplx: bool):
    batch_size = 4
    x, A, zi = tuple(
        x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
    )
    numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
    ext_y = torch.ops.torchlpc.lpc_cpu(x, A, zi)

    assert torch.allclose(numba_y, ext_y)
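Both backends under test evaluate the same time-varying all-pole recursion, y[t] = x[t] - sum_i A[t, i] * y[t - 1 - i], seeded by the initial conditions zi. A minimal loop-based reference sketch, with shapes assumed from the checks in the C++ op below (x of shape (B, T), A of shape (B, T, order), zi of shape (B, order)); it mirrors the [flip(zi), x] padding the kernel uses:

# Hedged reference sketch of the recursion computed by lpc_np and the
# compiled lpc_cpu op; slow, for illustration only.
import torch

def lpc_reference(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
    B, T = x.shape
    order = A.shape[-1]
    # Same layout the C++ kernel uses: pad the output buffer with flipped zi.
    y = torch.cat([zi.flip(1), x.clone()], dim=1)
    for t in range(T):
        acc = y[:, order + t]
        for i in range(order):
            acc = acc - A[:, t, i] * y[:, order + t - 1 - i]
        y[:, order + t] = acc
    return y[:, order:]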
7 changes: 7 additions & 0 deletions torchlpc/core.py
@@ -1,10 +1,12 @@
import warnings
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Function
from typing import Any, Tuple, Optional, Callable, List
from numba import jit, njit, prange, cuda, float32, float64, complex64, complex128

from . import EXTENSION_LOADED

lpc_cuda_kernel_float32: Callable = None
lpc_cuda_kernel_float64: Callable = None
@@ -159,7 +161,12 @@ class LPC(Function):
    def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
        if x.is_cuda:
            y = lpc_cuda(x.detach(), A.detach(), zi.detach())
        elif EXTENSION_LOADED:
            y = torch.ops.torchlpc.lpc_cpu(x, A, zi)
        else:
            warnings.warn(
                "Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."
            )
            y = lpc_np(
                x.detach().cpu().numpy(),
                A.detach().cpu().numpy(),
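Which CPU path forward takes is decided once, at import time, by EXTENSION_LOADED (imported from the package root, as the diff above shows). A quick check, as a sketch:

# Sketch: report whether the compiled extension or the Numba fallback
# will serve CPU calls to LPC.forward.
import torchlpc

if torchlpc.EXTENSION_LOADED:
    print("CPU calls use torch.ops.torchlpc.lpc_cpu")
else:
    print("CPU calls fall back to the Numba lpc_np implementation")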
68 changes: 67 additions & 1 deletion torchlpc/csrc/scan_cpu.cpp
@@ -62,6 +62,47 @@ void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
        [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
}

template <typename scalar_t>
void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
    // Ensure input dimensions are correct
    TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
    TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
    TORCH_CHECK(padded_out.size(0) == a.size(0),
                "Batch size of out and x must match");
    TORCH_CHECK(padded_out.size(1) == (a.size(1) + a.size(2)),
                "Time dimension of out must match x and a");
    TORCH_INTERNAL_ASSERT(a.device().is_cpu(), "a must be on CPU");
    TORCH_INTERNAL_ASSERT(padded_out.device().is_cpu(),
                          "Output must be on CPU");
    TORCH_INTERNAL_ASSERT(padded_out.is_contiguous(),
                          "Output must be contiguous");

    // Get the dimensions
    const auto B = a.size(0);
    const auto T = a.size(1);
    const auto order = a.size(2);

    auto a_contiguous = a.contiguous();

    const scalar_t *a_ptr = a_contiguous.data_ptr<scalar_t>();
    scalar_t *out_ptr = padded_out.data_ptr<scalar_t>();

    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
        for (auto b = start; b < end; b++) {
            auto out_offset = b * (T + order) + order;
            auto a_offset = b * T * order;
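            // Time recursion: out[order + t] starts as x[t] (out was built as
            // [flip(zi), x]) and is updated in place to
            // y[t] = x[t] - sum_{i=0}^{order-1} a[t][i] * y[t - 1 - i].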
            for (int64_t t = 0; t < T; t++) {
                scalar_t y = out_ptr[out_offset + t];
                for (int64_t i = 0; i < order; i++) {
                    y -= a_ptr[a_offset + t * order + i] *
                         out_ptr[out_offset + t - i - 1];
                }
                out_ptr[out_offset + t] = y;
            }
        }
    });
}

at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
                            const at::Tensor &initials) {
    TORCH_CHECK(input.is_floating_point() || input.is_complex(),
@@ -79,8 +120,33 @@ at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
    return output;
}

at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
                   const at::Tensor &zi) {
    TORCH_CHECK(x.is_floating_point() || x.is_complex(),
                "Input must be floating point or complex");
    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
                "Coefficients must have the same scalar type as input");
    TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
                "Initial conditions must have the same scalar type as input");

    TORCH_CHECK(x.dim() == 2, "Input must be 2D");
    TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
    TORCH_CHECK(x.size(0) == zi.size(0),
                "Batch size of input and initial conditions must match");

    auto out = at::cat({zi.flip(1), x}, 1).contiguous();

    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
        x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
    return out.slice(1, zi.size(1), out.size(1)).contiguous();
}

TORCH_LIBRARY(torchlpc, m) {
    m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
    m.def("torchlpc::lpc_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); }
TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
    m.impl("scan_cpu", &scan_cpu_wrapper);
    m.impl("lpc_cpu", &lpc_cpu);
}
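With the schema defined in TORCH_LIBRARY and the CPU kernels registered in TORCH_LIBRARY_IMPL, the ops become callable from Python through torch.ops once the shared library is loaded; a hedged usage sketch, assuming that importing torchlpc loads the compiled torchlpc._C module built by setup.py:

# Usage sketch for the newly registered op; assumes `import torchlpc`
# loads torchlpc._C and thereby runs the registrations above.
import torch
import torchlpc  # noqa: F401

B, T, order = 4, 1024, 2
x = torch.randn(B, T, dtype=torch.double)
A = 0.1 * torch.randn(B, T, order, dtype=torch.double)
zi = torch.zeros(B, order, dtype=torch.double)

y = torch.ops.torchlpc.lpc_cpu(x, A, zi)  # filtered output, shape (B, T)
print(y.shape, y.dtype)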