Skip to content

Commit a954207

Browse files
committed
feat: enhance sample_wise_lpc to return zf
1 parent 37e8115 commit a954207

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

torchlpc/__init__.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from typing import Optional
2+
from typing import Optional, Union, Tuple
33
from pathlib import Path
44
import warnings
55

@@ -31,22 +31,27 @@
3131

3232

3333
def sample_wise_lpc(
34-
x: torch.Tensor, a: torch.Tensor, zi: Optional[torch.Tensor] = None
35-
) -> torch.Tensor:
34+
x: torch.Tensor,
35+
a: torch.Tensor,
36+
zi: Optional[torch.Tensor] = None,
37+
return_zf: bool = False,
38+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
3639
"""Compute LPC filtering sample-wise.
3740
3841
Args:
3942
x (torch.Tensor): Input signal.
4043
a (torch.Tensor): LPC coefficients.
4144
zi (torch.Tensor): Initial conditions.
45+
return_zf (bool): If True, return the final filter delay values. Defaults to False.
4246
4347
Shape:
4448
- x: :math:`(B, T)`
4549
- a: :math:`(B, T, order)`
4650
- zi: :math:`(B, order)`
4751
4852
Returns:
49-
torch.Tensor: Filtered signal with the same shape as x.
53+
Filtered signal with the same shape as x if `return_zf` is False.
54+
If `return_zf` is True, returns a tuple of the filtered signal and the final delay values.
5055
"""
5156
assert x.shape[0] == a.shape[0]
5257
assert x.shape[1] == a.shape[1]
@@ -62,6 +67,10 @@ def sample_wise_lpc(
6267
# if order == 1 and x.is_cuda and B * WARPSIZE < T:
6368
# return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
6469
if order == 1:
65-
return Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1))
70+
y = Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1))
71+
else:
72+
y = LPC.apply(x, a, zi)
6673

67-
return LPC.apply(x, a, zi)
74+
if return_zf:
75+
return y, y[:, -order:].flip(1)
76+
return y

0 commit comments

Comments
 (0)