11import torch
2- from typing import Optional
2+ from typing import Optional , Union , Tuple
33from pathlib import Path
44import warnings
55
3131
3232
3333def sample_wise_lpc (
34- x : torch .Tensor , a : torch .Tensor , zi : Optional [torch .Tensor ] = None
35- ) -> torch .Tensor :
34+ x : torch .Tensor ,
35+ a : torch .Tensor ,
36+ zi : Optional [torch .Tensor ] = None ,
37+ return_zf : bool = False ,
38+ ) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
3639 """Compute LPC filtering sample-wise.
3740
3841 Args:
3942 x (torch.Tensor): Input signal.
4043 a (torch.Tensor): LPC coefficients.
4144 zi (torch.Tensor): Initial conditions.
45+ return_zf (bool): If True, return the final filter delay values. Defaults to False.
4246
4347 Shape:
4448 - x: :math:`(B, T)`
4549 - a: :math:`(B, T, order)`
4650 - zi: :math:`(B, order)`
4751
4852 Returns:
49- torch.Tensor: Filtered signal with the same shape as x.
53+ Filtered signal with the same shape as x if `return_zf` is False.
54+ If `return_zf` is True, returns a tuple of the filtered signal and the final delay values.
5055 """
5156 assert x .shape [0 ] == a .shape [0 ]
5257 assert x .shape [1 ] == a .shape [1 ]
@@ -62,6 +67,10 @@ def sample_wise_lpc(
6267 # if order == 1 and x.is_cuda and B * WARPSIZE < T:
6368 # return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
6469 if order == 1 :
65- return Recurrence .apply (- a .squeeze (2 ), x , zi .squeeze (1 ))
70+ y = Recurrence .apply (- a .squeeze (2 ), x , zi .squeeze (1 ))
71+ else :
72+ y = LPC .apply (x , a , zi )
6673
67- return LPC .apply (x , a , zi )
74+ if return_zf :
75+ return y , y [:, - order :].flip (1 )
76+ return y
0 commit comments