
Commit 61b2475

Respect core type shape in gradient of Blockwise
1 parent 94b0e6b

File tree

2 files changed: +54 -3 lines


pytensor/tensor/blockwise.py

Lines changed: 4 additions & 1 deletion
@@ -355,7 +355,10 @@ def as_core(t, core_t):
 
         with config.change_flags(compute_test_value="off"):
             safe_inputs = [
-                tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
+                tensor(
+                    dtype=inp.type.dtype,
+                    shape=inp.type.shape[inp.type.ndim - len(sig) :],
+                )
                 for inp, sig in zip(inputs, self.inputs_sig, strict=True)
             ]
             core_node = self._create_dummy_core_node(safe_inputs)
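To illustrate what the new shape slice does, here is a minimal sketch (not part of the commit, assuming a core signature of length 1 such as "(a)"); the name sig below is a hypothetical stand-in for one entry of self.inputs_sig:

from pytensor.tensor import tensor

# Batched input: batch dimension of size 5, one core dimension of known size 2.
inp = tensor("inp", shape=(5, 2), dtype="float64")
sig = ("a",)  # hypothetical stand-in for one entry of self.inputs_sig

# Old behaviour: every core dimension of the dummy core input became unknown.
old_core_shape = (None,) * len(sig)
# New behaviour: the trailing core dimensions keep their static sizes.
new_core_shape = inp.type.shape[inp.type.ndim - len(sig) :]

assert old_core_shape == (None,)
assert new_core_shape == (2,)

Keeping the trailing core shape means core Ops that assert a specific core dimension size in make_node or L_op (like the test Op added below) receive correctly typed dummy inputs when the gradient of a Blockwise is built.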

tests/tensor/test_blockwise.py

Lines changed: 50 additions & 2 deletions
@@ -6,11 +6,11 @@
 import scipy.linalg
 
 import pytensor
-from pytensor import In, config, function
+from pytensor import In, config, function, scan
 from pytensor.compile import get_default_mode, get_mode
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
-from pytensor.graph.replace import vectorize_node
+from pytensor.graph.replace import vectorize_graph, vectorize_node
 from pytensor.raise_op import assert_op
 from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
@@ -650,3 +650,51 @@ def L_op(self, inputs, outputs, output_gradients):
         np.ones(12, dtype=config.floatX),
         strict=True,
     )
+
+
+def test_blockwise_grad_core_type():
+    class StrictCoreTypeOp(Op):
+        def make_node(self, x):
+            assert x.type.shape[-1] == 2
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            output_storage[0][0] = inputs[0] + 1
+
+        def L_op(self, inputs, outputs, output_grads):
+            [x] = inputs
+            assert x.type.shape == (2,)
+            return [x.zeros_like()]
+
+    strict_core_type_op = StrictCoreTypeOp()
+    block_strict_core_type_op = Blockwise(strict_core_type_op, signature="(a)->(a)")
+
+    x = tensor("x", shape=(5, 2), dtype="float64")
+    y = block_strict_core_type_op(x)
+    assert y.type.shape == (5, 2)
+
+    grad_y = grad(y.sum(), x)
+    assert grad_y.type.shape == (5, 2)
+    np.testing.assert_allclose(
+        grad_y.eval({x: np.ones((5, 2))}),
+        np.zeros((5, 2)),
+    )
+
+
+def test_scan_gradient_core_type():
+    n_steps = 3
+    seq = tensor("seq", shape=(n_steps, 1))
+    out, _ = scan(
+        lambda s: s,
+        sequences=[seq],
+        n_steps=n_steps,
+    )
+
+    vec_seq = tensor("vec_seq", shape=(None, n_steps, 1))
+    vec_out = vectorize_graph(out, replace={seq: vec_seq})
+    grad_sit_sot0 = grad(vec_out.sum(), vec_seq)
+
+    np.testing.assert_allclose(
+        grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}),
+        np.ones((4, n_steps, 1)),
+    )
