Skip to content

Fix shape issues in jax tridiagonal solve #1414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 24, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion pytensor/link/jax/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,21 @@ def solve(a, b):
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower)
# jax requires dl and du to have the same shape as d
dl = jax.numpy.pad(dl, (1, 0))
du = jax.numpy.pad(du, (0, 1))
# if b is a vector, broadcast it to be a matrix
b_is_vec = len(b.shape) == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to check this at runtime. The Solve Op has a property b_ndim, so you can do:

b_is_vec = op.b_ndim

if assume_a == 'tridiagonal':
    ... # carry on

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check as written will also fail in the batched case (that's why we have it at the Op level)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you'd have to do b_is_vec = op.b_ndim == 1 though, no? because bool(op.b_ndim) -> True for b_ndim > 0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly, my code has an error

if b_is_vec:
b = jax.numpy.expand_dims(b, -1)

res = jax.lax.linalg.tridiagonal_solve(dl, d, du, b)

if b_is_vec:
# if b is a vector, return a vector
return res.flatten()
else:
return res

else:
if assume_a not in ("gen", "sym", "her", "pos"):
Expand Down
32 changes: 32 additions & 0 deletions tests/link/jax/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,38 @@ def test_jax_solve():
)


def test_jax_tridiagonal_solve():
N = 10
A = pt.matrix("A", shape=(N, N))
b = pt.vector("b", shape=(N,))

out = pt.linalg.solve(A, b, assume_a="tridiagonal")

A_val = np.eye(N)
for i in range(N - 1):
A_val[i, i + 1] = np.random.randn()
A_val[i + 1, i] = np.random.randn()

b_val = np.random.randn(N)

compare_jax_and_py(
[A, b],
[out],
[A_val, b_val],
)

b_ = pt.matrix("b", shape=(N, 2))

out = pt.linalg.solve(A, b_, assume_a="tridiagonal")
b_val = np.random.randn(N, 2)

compare_jax_and_py(
[A, b_],
[out],
[A_val, b_val],
)


def test_jax_SolveTriangular():
rng = np.random.default_rng(utt.fetch_seed())

Expand Down