|
6 | 6 | import scipy.linalg
|
7 | 7 |
|
8 | 8 | import pytensor
|
9 |
| -from pytensor import In, config, function |
| 9 | +from pytensor import In, config, function, scan |
10 | 10 | from pytensor.compile import get_default_mode, get_mode
|
11 | 11 | from pytensor.gradient import grad
|
12 | 12 | from pytensor.graph import Apply, Op
|
13 |
| -from pytensor.graph.replace import vectorize_node |
| 13 | +from pytensor.graph.replace import vectorize_graph, vectorize_node |
14 | 14 | from pytensor.raise_op import assert_op
|
15 | 15 | from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
|
16 | 16 | from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
|
@@ -650,3 +650,51 @@ def L_op(self, inputs, outputs, output_gradients):
|
650 | 650 | np.ones(12, dtype=config.floatX),
|
651 | 651 | strict=True,
|
652 | 652 | )
|
| 653 | + |
| 654 | + |
def test_blockwise_grad_core_type():
    """Blockwise gradients must present the core (unbatched) type to the inner Op.

    The inner Op asserts, both at graph-construction and at gradient time,
    that its input has the core shape ``(2,)`` rather than the batched
    shape ``(5, 2)``.
    """

    class AssertCoreShapeOp(Op):
        # Core op whose make_node and L_op both insist on a core shape of (2,).
        def make_node(self, x):
            assert x.type.shape[-1] == 2
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            output_storage[0][0] = inputs[0] + 1

        def L_op(self, inputs, outputs, output_grads):
            (core_x,) = inputs
            # The gradient must see the core type, not the batched one.
            assert core_x.type.shape == (2,)
            return [core_x.zeros_like()]

    blockwise_op = Blockwise(AssertCoreShapeOp(), signature="(a)->(a)")

    x = tensor("x", shape=(5, 2), dtype="float64")
    out = blockwise_op(x)
    assert out.type.shape == (5, 2)

    x_grad = grad(out.sum(), x)
    assert x_grad.type.shape == (5, 2)
    np.testing.assert_allclose(
        x_grad.eval({x: np.ones((5, 2))}),
        np.zeros((5, 2)),
    )
| 682 | + |
| 683 | + |
def test_scan_gradient_core_type():
    """Gradient through a vectorized Scan of the identity map is all ones."""
    steps = 3
    core_seq = tensor("seq", shape=(steps, 1))

    def step(seq_elem):
        # Identity inner function: each scan step passes its input through.
        return seq_elem

    scan_out, _ = scan(step, sequences=[core_seq], n_steps=steps)

    # Add a leading batch dimension and vectorize the scan over it.
    batched_seq = tensor("vec_seq", shape=(None, steps, 1))
    batched_out = vectorize_graph(scan_out, replace={core_seq: batched_seq})
    seq_grad = grad(batched_out.sum(), batched_seq)

    ones = np.ones((4, steps, 1))
    np.testing.assert_allclose(seq_grad.eval({batched_seq: ones}), ones)
0 commit comments