Skip to content

Commit 5e7a914

Browse files
committed
refactor: add warning for missing _C*.so file and check extension loading in Recurrence
1 parent 1472acf commit 5e7a914

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

torchlpc/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from typing import Optional
33
from pathlib import Path
4+
import warnings
45

56
so_files = list(Path(__file__).parent.glob("_C*.so"))
67
# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
@@ -10,6 +11,7 @@
1011
elif len(so_files) > 1:
1112
raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
1213
else:
14+
warnings.warn("No _C*.so file found. Custom extension not loaded.")
1315
EXTENSION_LOADED = False
1416

1517
from .core import LPC

torchlpc/recurrence.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .parallel_scan import compute_linear_recurrence, WARPSIZE
88
from .core import lpc_cuda, lpc_np
9+
from . import EXTENSION_LOADED
910

1011

1112
class Recurrence(Function):
@@ -32,7 +33,7 @@ def forward(
3233
else:
3334
num_threads = torch.get_num_threads()
3435
# This is just a rough estimation of the computational cost
35-
if min(n_dims, num_threads) < num_threads / 3:
36+
if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
3637
out = torch.ops.torchlpc.scan_cpu(impulse, decay, initial_state)
3738
else:
3839
out = torch.from_numpy(

0 commit comments

Comments
 (0)