Skip to content

Commit 7f2c29a

Browse files
committed
MAINT: signal: bilinear_zpk array API
1 parent a75106b commit 7f2c29a

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

scipy/signal/_filter_design.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,8 +2863,11 @@ def bilinear_zpk(z, p, k, fs):
28632863
>>> plt.ylabel('Amplitude [dB]')
28642864
>>> plt.grid(True)
28652865
"""
2866-
z = atleast_1d(z)
2867-
p = atleast_1d(p)
2866+
xp = array_namespace(z, p)
2867+
2868+
z, p = map(xp.asarray, (z, p))
2869+
z = xpx.atleast_nd(z, ndim=1, xp=xp)
2870+
p = xpx.atleast_nd(p, ndim=1, xp=xp)
28682871

28692872
fs = _validate_fs(fs, allow_none=False)
28702873

@@ -2877,10 +2880,10 @@ def bilinear_zpk(z, p, k, fs):
28772880
p_z = (fs2 + p) / (fs2 - p)
28782881

28792882
# Any zeros that were at infinity get moved to the Nyquist frequency
2880-
z_z = append(z_z, -ones(degree))
2883+
z_z = xp.concat((z_z, -xp.ones(degree)))
28812884

28822885
# Compensate for gain change
2883-
k_z = k * real(prod(fs2 - z) / prod(fs2 - p))
2886+
k_z = k * xp.real(xp.prod(fs2 - z) / xp.prod(fs2 - p))
28842887

28852888
return z_z, p_z, k_z
28862889

scipy/signal/tests/test_filter_design.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,19 +1615,25 @@ def test_basic(self, xp):
16151615

16161616
class TestBilinear_zpk:
16171617

1618-
def test_basic(self):
1619-
z = [-2j, +2j]
1620-
p = [-0.75, -0.5-0.5j, -0.5+0.5j]
1618+
def test_basic(self, xp):
1619+
z = xp.asarray([-2j, +2j])
1620+
p = xp.asarray([-0.75, -0.5-0.5j, -0.5+0.5j])
16211621
k = 3
16221622

16231623
z_d, p_d, k_d = bilinear_zpk(z, p, k, 10)
16241624

1625-
xp_assert_close(sort(z_d), sort([(20-2j)/(20+2j), (20+2j)/(20-2j),
1626-
-1]))
1627-
xp_assert_close(sort(p_d), sort([77/83,
1628-
(1j/2 + 39/2) / (41/2 - 1j/2),
1629-
(39/2 - 1j/2) / (1j/2 + 41/2), ]))
1630-
xp_assert_close(k_d, 9696/69803)
1625+
xp_assert_close(
1626+
_sort_cmplx(z_d, xp=xp),
1627+
_sort_cmplx([(20-2j) / (20+2j), (20+2j) / (20-2j), -1], xp=xp)
1628+
)
1629+
xp_assert_close(
1630+
_sort_cmplx(p_d, xp=xp),
1631+
_sort_cmplx(
1632+
[77/83, (1j/2 + 39/2) / (41/2 - 1j/2), (39/2 - 1j/2) / (1j/2 + 41/2)],
1633+
xp=xp
1634+
)
1635+
)
1636+
assert math.isclose(k_d, 9696/69803)
16311637

16321638

16331639
class TestPrototypeType:

0 commit comments

Comments
 (0)