Skip to content

Commit 5b63faf

Browse files
committed
Benchmark partial jacobian
1 parent 9a31837 commit 5b63faf

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

tests/test_gradient.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pytensor.graph.null_type import NullType
3333
from pytensor.graph.op import Op
3434
from pytensor.scan.op import Scan
35-
from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, tanh
35+
from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh
3636
from pytensor.tensor.math import sum as pt_sum
3737
from pytensor.tensor.random import RandomStream
3838
from pytensor.tensor.type import (
@@ -1143,6 +1143,24 @@ def test_benchmark(self, vectorize, benchmark):
11431143
fn = function([x], jac_y, trust_input=True)
11441144
benchmark(fn, np.array([0, 1, 2], dtype=x.type.dtype))
11451145

1146+
def test_benchmark_partial_jacobian(self, vectorize, benchmark):
1147+
# Example from https://github.com/jax-ml/jax/discussions/5904#discussioncomment-422956
1148+
N = 1000
1149+
rng = np.random.default_rng(2025)
1150+
x_test = rng.random((N,))
1151+
1152+
f_mat = rng.random((N, N))
1153+
x = vector("x", dtype="float64")
1154+
1155+
def f(x):
1156+
return sqrt(f_mat @ x / N)
1157+
1158+
full_jacobian = jacobian(f(x), x, vectorize=vectorize)
1159+
partial_jacobian = full_jacobian[:5, :5]
1160+
1161+
f = pytensor.function([x], partial_jacobian, trust_input=True)
1162+
benchmark(f, x_test)
1163+
11461164

11471165
def test_hessian():
11481166
x = vector()

0 commit comments

Comments
 (0)