|
32 | 32 | from pytensor.graph.null_type import NullType
|
33 | 33 | from pytensor.graph.op import Op
|
34 | 34 | from pytensor.scan.op import Scan
|
35 |
| -from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, tanh |
| 35 | +from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh |
36 | 36 | from pytensor.tensor.math import sum as pt_sum
|
37 | 37 | from pytensor.tensor.random import RandomStream
|
38 | 38 | from pytensor.tensor.type import (
|
@@ -1143,6 +1143,24 @@ def test_benchmark(self, vectorize, benchmark):
|
1143 | 1143 | fn = function([x], jac_y, trust_input=True)
|
1144 | 1144 | benchmark(fn, np.array([0, 1, 2], dtype=x.type.dtype))
|
1145 | 1145 |
|
| 1146 | + def test_benchmark_partial_jacobian(self, vectorize, benchmark): |
| 1147 | + # Example from https://github.com/jax-ml/jax/discussions/5904#discussioncomment-422956 |
| 1148 | + N = 1000 |
| 1149 | + rng = np.random.default_rng(2025) |
| 1150 | + x_test = rng.random((N,)) |
| 1151 | + |
| 1152 | + f_mat = rng.random((N, N)) |
| 1153 | + x = vector("x", dtype="float64") |
| 1154 | + |
| 1155 | + def f(x): |
| 1156 | + return sqrt(f_mat @ x / N) |
| 1157 | + |
| 1158 | + full_jacobian = jacobian(f(x), x, vectorize=vectorize) |
| 1159 | + partial_jacobian = full_jacobian[:5, :5] |
| 1160 | + |
| 1161 | + f = pytensor.function([x], partial_jacobian, trust_input=True) |
| 1162 | + benchmark(f, x_test) |
| 1163 | + |
1146 | 1164 |
|
1147 | 1165 | def test_hessian():
|
1148 | 1166 | x = vector()
|
|
0 commit comments