
Commit 9a31837

Define all batched dot operations as matmul
A new rewrite converts unpaired batched row/column matvec or vec products into equivalent matmul products.
1 parent 480823b commit 9a31837
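
In effect, vecdot, matvec and vecmat become thin wrappers around a single _matmul Blockwise that insert and squeeze length-1 dimensions. A minimal NumPy sketch of the shape bookkeeping the new helpers rely on (array names are illustrative, not taken from the diff):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))     # batch of vectors
y = rng.normal(size=(5, 3))
A = rng.normal(size=(5, 4, 3))  # batch of matrices
B = rng.normal(size=(5, 3, 4))

# vecdot: pair the vectors up as row @ column, then drop the dummy axes
vecdot = (x[..., None, :] @ y[..., :, None]).squeeze((-2, -1))
assert np.allclose(vecdot, np.einsum("bi,bi->b", x, y))

# matvec: treat the vector as a column matrix, matmul, drop the dummy column
matvec = (A @ x[..., None]).squeeze(-1)
assert np.allclose(matvec, np.einsum("bij,bj->bi", A, x))

# vecmat: transpose the matrix and reuse the matvec trick
vecmat = (np.swapaxes(B, -1, -2) @ x[..., None]).squeeze(-1)
assert np.allclose(vecmat, np.einsum("bi,bij->bj", x, B))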

File tree

5 files changed: +218 -80 lines changed

pytensor/tensor/math.py

Lines changed: 16 additions & 32 deletions

@@ -3916,23 +3916,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


-# Predefine all batched variations of Dot
-_inner_prod = Blockwise(
-    _dot,
-    signature="(n),(n)->()",
-)
-
-_matrix_vec_prod = Blockwise(
-    _dot,
-    signature="(m,k),(k)->(m)",
-)
-
-_vec_matrix_prod = Blockwise(
-    _dot,
-    signature="(k),(k,n)->(n)",
-)
-
-_matrix_matrix_matmul = Blockwise(
+_matmul = Blockwise(
     _dot,
     signature="(m,k),(k,n)->(m,n)",
     gufunc_spec=("numpy.matmul", 2, 1),
@@ -3988,11 +3972,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     if x1.type.ndim == 1 and x2.type.ndim == 1:
         out = _dot(x1, x2)
     elif x1.type.ndim == 1:
-        out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
+        out = vecmat(x1, x2)
     elif x2.type.ndim == 1:
-        out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
+        out = matvec(x1, x2)
     else:
-        out = _matrix_matrix_matmul(x1, x2)
+        out = _matmul(x1, x2)

     if dtype is not None:
         out = out.astype(dtype)
@@ -4042,7 +4026,7 @@ def vecdot(
     >>> z_batch = pt.vecdot(x_batch, y_batch)  # shape (3,)
     >>> # Equivalent to numpy.vecdot(x_batch, y_batch)
     """
-    out = _inner_prod(x1, x2)
+    out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1))

     if dtype is not None:
         out = out.astype(dtype)
@@ -4091,7 +4075,7 @@ def matvec(
     >>> result = pt.matvec(batched_A, batched_v)  # shape (2, 3)
     >>> # Equivalent to numpy.matvec(batched_A, batched_v)
     """
-    out = _matrix_vec_prod(x1, x2)
+    out = matmul(x1, x2[..., None]).squeeze(-1)

     if dtype is not None:
         out = out.astype(dtype)
@@ -4129,18 +4113,18 @@ def vecmat(
     --------
     >>> import pytensor.tensor as pt
     >>> # Vector-matrix product
-    >>> v = pt.vector("v", shape=(3,))  # shape (3,)
-    >>> A = pt.matrix("A", shape=(3, 4))  # shape (3, 4)
+    >>> v = pt.vector("v", shape=(3,))
+    >>> A = pt.matrix("A", shape=(3, 4))
     >>> result = pt.vecmat(v, A)  # shape (4,)
     >>> # Equivalent to numpy.vecmat(v, A)
     >>>
     >>> # Batched vector-matrix product
-    >>> batched_v = pt.matrix("v", shape=(2, 3))  # shape (2, 3)
-    >>> batched_A = pt.tensor3("A", shape=(2, 3, 4))  # shape (2, 3, 4)
+    >>> batched_v = pt.matrix("v", shape=(2, 3))
+    >>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
     >>> result = pt.vecmat(batched_v, batched_A)  # shape (2, 4)
     >>> # Equivalent to numpy.vecmat(batched_v, batched_A)
     """
-    out = _vec_matrix_prod(x1, x2)
+    out = matmul(x2.mT, x1[..., None]).squeeze(-1)

     if dtype is not None:
         out = out.astype(dtype)
@@ -4155,18 +4139,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y):
     old_y_ndim = old_y.type.ndim
     match (old_x_ndim, old_y_ndim):
         case (1, 1):
-            batch_op = _inner_prod
+            batch_fn = vecdot
         case (2, 1):
-            batch_op = _matrix_vec_prod
+            batch_fn = matvec
         case (1, 2):
-            batch_op = _vec_matrix_prod
+            batch_fn = vecmat
         case (2, 2):
-            batch_op = _matrix_matrix_matmul
+            batch_fn = matmul
         case _:
             raise ValueError(
                 f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
             )
-    return batch_op(batched_x, batched_y).owner
+    return batch_fn(batched_x, batched_y).owner


 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
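
A quick sanity check of the public helpers after this change, following the docstring examples above (a sketch, assuming pt.vecdot/pt.matvec/pt.vecmat and .type.shape behave as in current PyTensor; the bv/bA names are illustrative):

import pytensor.tensor as pt

v = pt.vector("v", shape=(3,))
A = pt.matrix("A", shape=(3, 4))
batched_v = pt.matrix("bv", shape=(2, 3))
batched_A = pt.tensor3("bA", shape=(2, 3, 4))

print(pt.vecmat(v, A).type.shape)                  # (4,)
print(pt.matvec(A.mT, v).type.shape)               # (4,), same result via the transposed matrix
print(pt.vecmat(batched_v, batched_A).type.shape)  # (2, 4)
print(pt.vecdot(batched_v, batched_v).type.shape)  # (2,)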

pytensor/tensor/rewriting/blas.py

Lines changed: 2 additions & 2 deletions

@@ -98,7 +98,7 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import (
     Dot,
-    _matrix_matrix_matmul,
+    _matmul,
     add,
     mul,
     neg,
@@ -908,7 +908,7 @@ def local_dot22_to_dot22scalar(fgraph, node):


 @register_specialize
-@node_rewriter([_matrix_matrix_matmul])
+@node_rewriter([_matmul])
 def specialize_matmul_to_batched_dot(fgraph, node):
     """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
+from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
 from pytensor.tensor.nlinalg import (
     SVD,
     KroneckerProduct,
@@ -282,7 +282,7 @@ def cholesky_ldotlt(fgraph, node):
                 # This rewrite only applies to matrix Dot
                 and A.owner.inputs[0].type.ndim == 2
             )
-            or (A.owner.op == _matrix_matrix_matmul)
+            or (A.owner.op == _matmul)
         )
     ):
         return

pytensor/tensor/rewriting/math.py

Lines changed: 116 additions & 44 deletions

@@ -28,14 +28,14 @@
     as_tensor_variable,
     cast,
     constant,
+    expand_dims,
     get_underlying_scalar_constant_value,
     moveaxis,
     ones_like,
     register_infer_shape,
     switch,
     zeros_like,
 )
-from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
@@ -44,10 +44,7 @@
     Prod,
     Sum,
     _conj,
-    _inner_prod,
-    _matrix_matrix_matmul,
-    _matrix_vec_prod,
-    _vec_matrix_prod,
+    _matmul,
     add,
     digamma,
     dot,
@@ -195,60 +192,135 @@ def local_lift_transpose_through_dot(fgraph, node):
     return ret


+@register_canonicalize
 @register_stabilize
 @register_specialize
-@node_rewriter(tracks=[Blockwise])
+@node_rewriter(tracks=[_matmul])
 def local_batched_matmul_to_core_matmul(fgraph, node):
-    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
+    """Move batch dimensions of matmul operands to core matmul

-    Example, if x has batch dimensions, but y not:
+    Example, if x has batch dimensions that don't overlap with batch dimensions of y
     x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])

-    It also works when y has batch dimensions, but x not.
+    It also works for batch dimensions of y that don't overlap with batch dimensions of x
     """

-    # Check whether we have a matmul operation in this node
-    if not (
-        isinstance(node.op.core_op, Dot)
-        and len(node.op.inputs_sig[0]) == 2
-        and len(node.op.inputs_sig[1]) == 2
-    ):
-        return None
-
     x, y = node.inputs
     batch_ndim = node.op.batch_ndim(node)

-    # Check if x has batch dimensions, but y not (or only broadcastable dimensions)
-    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
-        y.type.broadcastable[:-2]
-    ):
-        x_stacked = x.reshape((-1, x.shape[-1]))
-        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))
-        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))
-        return [out]
-
-    # Otherwise, check if y has batch dimension, but x not
-    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
-        x.type.broadcastable[:-2]
-    ):
-        # For the y batch case we need to first move the batch axes and then reshape
-        # y.shape == (*b, k, n)
-        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
-        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, *b * n)
-        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, *b * n)
-        out_stacked_tr = out_stacked.reshape(
-            (x.shape[-2], *y.shape[:-2], y.shape[-1])
-        )  # (m, *b, n)
-        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
-        return [out]
-
-    # Both x and y have batch dimensions, nothing to do here
-    return None
+    x_axis_to_merge = [
+        i
+        for i, (bcast_x, bcast_y) in enumerate(
+            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
+        )
+        if bcast_y and not bcast_x
+    ]
+
+    y_axis_to_merge = [
+        i
+        for i, (bcast_x, bcast_y) in enumerate(
+            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
+        )
+        if bcast_x and not bcast_y
+    ]
+
+    if not (x_axis_to_merge or y_axis_to_merge):
+        return None
+
+    x_shape = tuple(x.shape)
+    y_shape = tuple(y.shape)
+    x_is_row = x.type.broadcastable[-2]
+    y_is_col = y.type.broadcastable[-1]
+    n_x_axis_to_merge = len(x_axis_to_merge)
+    n_y_axis_to_merge = len(y_axis_to_merge)
+    n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge
+
+    x_stacked, y_stacked = x, y
+    dims_were_merged = False
+
+    if n_x_axis_to_merge:
+        # ravel batch dimensions of x on the core (m) axis
+        x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2))
+        x_stacked = moveaxis(x, x_axis_to_merge, x_axis_destination)
+        if x_is_row:
+            # x was a row matrix, squeeze it to clean up the graph
+            x_stacked = x_stacked.squeeze(-2)
+        if n_x_axis_to_merge > 1 or not x_is_row:
+            # Ravel moved batch dims together with (m) if needed
+            x_stacked_shape = tuple(x_stacked.shape)
+            x_stacked = x_stacked.reshape(
+                (*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1])
+            )
+            dims_were_merged = True
+
+    if n_y_axis_to_merge:
+        # ravel batch dimensions of y on the core (n) axis
+        y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1))
+        y_stacked = moveaxis(y, y_axis_to_merge, y_axis_destination)
+        if y_is_col:
+            # y was a column matrix, squeeze it to clean up the graph
+            y_stacked = y_stacked.squeeze(-1)
+        if n_y_axis_to_merge > 1 or not y_is_col:
+            # Ravel moved batch dims together with (n) if needed
+            y_stacked_shape = tuple(y_stacked.shape)
+            y_stacked = y_stacked.reshape(
+                (*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1)
+            )
+            dims_were_merged = True
+
+    # Squeeze x_dims corresponding to merged dimensions of y
+    x_axis_to_squeeze = np.array(y_axis_to_merge)
+    for i in reversed(x_axis_to_merge):
+        # The corresponding dimensions of y may have shifted when we merged dimensions of x
+        x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1
+    x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze))
+
+    # Same for y
+    y_axis_to_squeeze = np.array(x_axis_to_merge)
+    for i in reversed(y_axis_to_merge):
+        y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1
+    y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze))
+
+    out_stacked = x_stacked @ y_stacked
+
+    # Split back any merged dimensions
+    if dims_were_merged:
+        x_merged_shapes = [x_shape[i] for i in x_axis_to_merge]
+        if not x_is_row:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            x_merged_shapes.append(x_shape[-2])
+        y_merged_shapes = [y_shape[i] for i in y_axis_to_merge]
+        if not y_is_col:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            y_merged_shapes.append(y_shape[-1])
+        out_stacked_shape = tuple(out_stacked.shape)
+        out_unstacked = out_stacked.reshape(
+            (
+                *out_stacked_shape[: batch_ndim - n_axis_to_merge],
+                *x_merged_shapes,
+                *y_merged_shapes,
+            )
+        )
+    else:
+        out_unstacked = out_stacked
+
+    # Add back dummy row, col axis
+    # We do this separately to avoid the reshape as much as we can
+    if y_is_col and (n_y_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -1)
+    if x_is_row and (n_x_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2)
+
+    # Move batch axis back to their original location
+    source = range(-n_axis_to_merge - 2, 0)
+    destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1)
+    out = moveaxis(out_unstacked, source, destination)
+    return [out]


 @register_canonicalize
 @register_specialize
-@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+@node_rewriter([_matmul])
 def local_blockwise_dot_to_mul(fgraph, node):
     """Rewrite blockwise dots that correspond to multiplication without summation.

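The local_batched_matmul_to_core_matmul rewrite above relies on the fact that a batch dimension which is broadcastable in the other operand can be folded into one of the core matmul dimensions. A NumPy sketch of the two basic equivalences it exploits (illustrative shapes; the actual rewrite builds the graph with moveaxis/reshape/squeeze/expand_dims as in the diff):

import numpy as np

rng = np.random.default_rng(0)
b, m, k, n = 4, 2, 3, 5

# Batch dimension only on x: fold it into the core (m) axis
x = rng.normal(size=(b, m, k))
y = rng.normal(size=(1, k, n))
batched = x @ y                                       # (b, m, n), broadcasted batched matmul
core = (x.reshape(b * m, k) @ y[0]).reshape(b, m, n)  # single core matmul
assert np.allclose(batched, core)

# Batch dimension only on y: fold it into the core (n) axis
x2 = rng.normal(size=(1, m, k))
y2 = rng.normal(size=(b, k, n))
batched2 = x2 @ y2                                    # (b, m, n)
y2_folded = np.moveaxis(y2, 0, -2).reshape(k, b * n)  # (k, b*n)
core2 = np.moveaxis((x2[0] @ y2_folded).reshape(m, b, n), 1, 0)
assert np.allclose(batched2, core2)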