diff --git a/doc/tutorial/gradients.rst b/doc/tutorial/gradients.rst index f8b7f7ff98..35dc852c77 100644 --- a/doc/tutorial/gradients.rst +++ b/doc/tutorial/gradients.rst @@ -101,9 +101,12 @@ PyTensor implements the :func:`pytensor.gradient.jacobian` macro that does all that is needed to compute the Jacobian. The following text explains how to do it manually. +Using Scan +---------- + In order to manually compute the Jacobian of some function ``y`` with -respect to some parameter ``x`` we need to use `scan`. What we -do is to loop over the entries in ``y`` and compute the gradient of +respect to some parameter ``x``, we can use `scan`. +In this case, we loop over the entries in ``y`` and compute the gradient of ``y[i]`` with respect to ``x``. .. note:: @@ -111,8 +114,7 @@ do is to loop over the entries in ``y`` and compute the gradient of `scan` is a generic op in PyTensor that allows writing in a symbolic manner all kinds of recurrent equations. While creating symbolic loops (and optimizing them for performance) is a hard task, - effort is being done for improving the performance of `scan`. We - shall return to :ref:`scan` later in this tutorial. + efforts are being made to improve the performance of `scan`. >>> import pytensor >>> import pytensor.tensor as pt >>> x = pt.dvector('x') >>> y = x ** 2 >>> J, updates = pytensor.scan(lambda i, y, x : pytensor.grad(y[i], x), sequences=pt.arange(y.shape[0]), non_sequences=[y, x]) >>> f = pytensor.function([x], J, updates=updates) >>> f([4, 4]) array([[ 8., 0.], [ 0., 8.]]) -What we do in this code is to generate a sequence of integers from ``0`` to -``y.shape[0]`` using `pt.arange`. Then we loop through this sequence, and -at each step, we compute the gradient of element ``y[i]`` with respect to +This code generates a sequence of integers from ``0`` to +``y.shape[0]`` using `pt.arange`. Then it loops through this sequence, and +at each step, computes the gradient of element ``y[i]`` with respect to ``x``. `scan` automatically concatenates all these rows, generating a matrix which corresponds to the Jacobian. @@ -139,6 +141,31 @@ matrix which corresponds to the Jacobian. ``x`` anymore, while ``y[i]`` still is. +Using automatic vectorization +----------------------------- +An alternative way to build the Jacobian is to vectorize the graph that computes a single row or column of the Jacobian. +We can use `Lop` or `Rop` (more about them below) to obtain a row or column of the Jacobian and `vectorize_graph` +to vectorize it into the full Jacobian matrix. + +>>> import pytensor +>>> import pytensor.tensor as pt +>>> from pytensor.gradient import Lop +>>> from pytensor.graph import vectorize_graph +>>> x = pt.dvector('x') +>>> y = x ** 2 +>>> row_cotangent = pt.dvector("row_cotangent") # Helper variable, it will be replaced during vectorization +>>> J_row = Lop(y, x, row_cotangent) +>>> J = vectorize_graph(J_row, replace={row_cotangent: pt.eye(x.size)}) +>>> f = pytensor.function([x], J) +>>> f([4, 4]) array([[ 8., 0.], [ 0., 8.]]) + +This avoids the overhead of scan, at the cost of higher memory usage if the Jacobian expression has large intermediate operations. +Also, not all graphs are safely vectorizable (e.g., if different rows require intermediate operations of different sizes). +For these reasons, `jacobian` uses scan by default. This behavior can be changed by setting `vectorize=True`.
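+The same result can be obtained directly from `jacobian` by passing ``vectorize=True`` (a minimal sketch reusing ``x`` and ``y`` from the example above): + +>>> from pytensor.gradient import jacobian +>>> J = jacobian(y, x, vectorize=True) +>>> f = pytensor.function([x], J) +>>> f([4, 4]) array([[ 8., 0.], [ 0., 8.]])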
+ + Computing the Hessian ===================== diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 96a39e09d9..5924fd7fcb 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -11,7 +11,7 @@ import pytensor from pytensor.compile.ops import ViewOp from pytensor.configdefaults import config -from pytensor.graph import utils +from pytensor.graph import utils, vectorize_graph from pytensor.graph.basic import Apply, NominalVariable, Variable from pytensor.graph.null_type import NullType, null_type from pytensor.graph.op import get_test_values @@ -703,15 +703,15 @@ def grad( grad_dict[var] = g_var def handle_disconnected(var): - message = ( - "grad method was asked to compute the gradient " - "with respect to a variable that is not part of " - "the computational graph of the cost, or is used " - f"only by a non-differentiable operator: {var}" - ) if disconnected_inputs == "ignore": - pass + return elif disconnected_inputs == "warn": + message = ( + "grad method was asked to compute the gradient " + "with respect to a variable that is not part of " + "the computational graph of the cost, or is used " + f"only by a non-differentiable operator: {var}" + ) warnings.warn(message, stacklevel=2) elif disconnected_inputs == "raise": message = utils.get_variable_trace_string(var) @@ -2021,13 +2021,19 @@ def __str__(self): Exception args: {args_msg}""" -def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"): +def jacobian( + expression, + wrt, + consider_constant=None, + disconnected_inputs="raise", + vectorize=False, +): """ Compute the full Jacobian, row by row. Parameters ---------- - expression : Vector (1-dimensional) :class:`~pytensor.graph.basic.Variable` + expression : :class:`~pytensor.graph.basic.Variable` Values that we are differentiating (that we want the Jacobian of) wrt : :class:`~pytensor.graph.basic.Variable` or list of Variables Term[s] with respect to which we compute the Jacobian @@ -2051,18 +2057,18 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise output, then a zero variable is returned. The return value is of same type as `wrt`: a list/tuple or TensorVariable in all cases. """ + from pytensor.tensor.basic import eye + from pytensor.tensor.extra_ops import broadcast_to if not isinstance(expression, Variable): raise TypeError("jacobian expects a Variable as `expression`") - if expression.ndim > 1: - raise ValueError( - "jacobian expects a 1 dimensional variable as `expression`."
- " If not use flatten to make it a vector" - ) - using_list = isinstance(wrt, list) using_tuple = isinstance(wrt, tuple) + grad_kwargs = { + "consider_constant": consider_constant, + "disconnected_inputs": disconnected_inputs, + } if isinstance(wrt, list | tuple): wrt = list(wrt) @@ -2070,43 +2076,55 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise wrt = [wrt] if all(expression.type.broadcastable): - # expression is just a scalar, use grad - return as_list_or_tuple( - using_list, - using_tuple, - grad( - expression.squeeze(), - wrt, - consider_constant=consider_constant, - disconnected_inputs=disconnected_inputs, - ), + jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs) + + elif vectorize: + expression_flat = expression.ravel() + row_tangent = _float_ones_like(expression_flat).type("row_tangent") + jacobian_single_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs) + + n_rows = expression_flat.size + jacobian_matrices = vectorize_graph( + jacobian_single_rows, + replace={row_tangent: eye(n_rows, dtype=row_tangent.dtype)}, ) + if disconnected_inputs != "raise": + # If the input is disconnected from the cost, `vectorize_graph` has no effect on the respective jacobian + # We have to broadcast the zeros explicitly here + for i, (jacobian_single_row, jacobian_matrix) in enumerate( + zip(jacobian_single_rows, jacobian_matrices, strict=True) + ): + if jacobian_single_row.ndim == jacobian_matrix.ndim: + jacobian_matrices[i] = broadcast_to( + jacobian_matrix, shape=(n_rows, *jacobian_matrix.shape) + ) - def inner_function(*args): - idx = args[0] - expr = args[1] - rvals = [] - for inp in args[2:]: - rval = grad( - expr[idx], - inp, - consider_constant=consider_constant, - disconnected_inputs=disconnected_inputs, + else: + + def inner_function(*args): + idx, expr, *wrt = args + return grad(expr[idx], wrt, **grad_kwargs) + + jacobian_matrices, updates = pytensor.scan( + inner_function, + sequences=pytensor.tensor.arange(expression.size), + non_sequences=[expression.ravel(), *wrt], + return_list=True, + ) + if updates: + raise ValueError( + "The scan used to build the jacobian matrices returned a list of updates" ) - rvals.append(rval) - return rvals - - # Computing the gradients does not affect the random seeds on any random - # generator used n expression (because during computing gradients we are - # just backtracking over old values. (rp Jan 2012 - if anyone has a - # counter example please show me) - jacobs, updates = pytensor.scan( - inner_function, - sequences=pytensor.tensor.arange(expression.shape[0]), - non_sequences=[expression, *wrt], - ) - assert not updates, "Scan has returned a list of updates; this should not happen." 
- return as_list_or_tuple(using_list, using_tuple, jacobs) + + if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim): + # There was some raveling or squeezing done prior to getting the jacobians + # Reshape into original shapes + jacobian_matrices = [ + jac_matrix.reshape((*expression.shape, *w.shape)) + for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True) + ] + + return as_list_or_tuple(using_list, using_tuple, jacobian_matrices) def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"): diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index 5092d55e6b..6cb46b6301 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -232,13 +232,13 @@ def vectorize_graph( def vectorize_graph( outputs: Sequence[Variable], replace: Mapping[Variable, Variable], -) -> Sequence[Variable]: ... +) -> list[Variable]: ... def vectorize_graph( outputs: Variable | Sequence[Variable], replace: Mapping[Variable, Variable], -) -> Variable | Sequence[Variable]: +) -> Variable | list[Variable]: """Vectorize outputs graph given mapping from old variables to expanded counterparts version. Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`. diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index b9e9c3164d..8225fd02ac 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3081,6 +3081,10 @@ def flatten(x, ndim=1): else: dims = (-1,) + if len(dims) == _x.ndim: + # Nothing to ravel + return _x + x_reshaped = _x.reshape(dims) shape_kept_dims = _x.type.shape[: ndim - 1] bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :]) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 714f597b32..b11b6164a8 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -3916,23 +3916,7 @@ def logsumexp(x, axis=None, keepdims=False): return log(sum(exp(x), axis=axis, keepdims=keepdims)) -# Predefine all batched variations of Dot -_inner_prod = Blockwise( - _dot, - signature="(n),(n)->()", -) - -_matrix_vec_prod = Blockwise( - _dot, - signature="(m,k),(k)->(m)", -) - -_vec_matrix_prod = Blockwise( - _dot, - signature="(k),(k,n)->(n)", -) - -_matrix_matrix_matmul = Blockwise( +_matmul = Blockwise( _dot, signature="(m,k),(k,n)->(m,n)", gufunc_spec=("numpy.matmul", 2, 1), @@ -3988,11 +3972,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None if x1.type.ndim == 1 and x2.type.ndim == 1: out = _dot(x1, x2) elif x1.type.ndim == 1: - out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2) + out = vecmat(x1, x2) elif x2.type.ndim == 1: - out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1) + out = matvec(x1, x2) else: - out = _matrix_matrix_matmul(x1, x2) + out = _matmul(x1, x2) if dtype is not None: out = out.astype(dtype) @@ -4042,7 +4026,7 @@ def vecdot( >>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,) >>> # Equivalent to numpy.vecdot(x_batch, y_batch) """ - out = _inner_prod(x1, x2) + out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1)) if dtype is not None: out = out.astype(dtype) @@ -4091,7 +4075,7 @@ def matvec( >>> result = pt.matvec(batched_A, batched_v) # shape (2, 3) >>> # Equivalent to numpy.matvec(batched_A, batched_v) """ - out = _matrix_vec_prod(x1, x2) + out = matmul(x1, x2[..., None]).squeeze(-1) if dtype is not None: out = out.astype(dtype) @@ -4129,18 +4113,18 @@ def vecmat( -------- >>> import pytensor.tensor as pt >>> # Vector-matrix product - >>> v = pt.vector("v", shape=(3,)) # 
shape (3,) - >>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4) + >>> v = pt.vector("v", shape=(3,)) + >>> A = pt.matrix("A", shape=(3, 4)) >>> result = pt.vecmat(v, A) # shape (4,) >>> # Equivalent to numpy.vecmat(v, A) >>> >>> # Batched vector-matrix product - >>> batched_v = pt.matrix("v", shape=(2, 3)) # shape (2, 3) - >>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4) + >>> batched_v = pt.matrix("v", shape=(2, 3)) + >>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) >>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4) >>> # Equivalent to numpy.vecmat(batched_v, batched_A) """ - out = _vec_matrix_prod(x1, x2) + out = matmul(x2.mT, x1[..., None]).squeeze(-1) if dtype is not None: out = out.astype(dtype) @@ -4155,18 +4139,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y): old_y_ndim = old_y.type.ndim match (old_x_ndim, old_y_ndim): case (1, 1): - batch_op = _inner_prod + batch_fn = vecdot case (2, 1): - batch_op = _matrix_vec_prod + batch_fn = matvec case (1, 2): - batch_op = _vec_matrix_prod + batch_fn = vecmat case (2, 2): - batch_op = _matrix_matrix_matmul + batch_fn = matmul case _: raise ValueError( f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D." ) - return batch_op(batched_x, batched_y).owner + return batch_fn(batched_x, batched_y).owner def nan_to_num(x, nan=0.0, posinf=None, neginf=None): diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 284f4af2b8..2f693b4ce4 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -31,6 +31,7 @@ from pytensor.compile.ops import ViewOp from pytensor.graph import FunctionGraph from pytensor.graph.basic import Constant +from pytensor.graph.op import _NoPythonOp from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, @@ -1108,7 +1109,12 @@ def unconditional_constant_folding(fgraph, node): storage_map[o] = [None] compute_map[o] = [False] - thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[]) + if isinstance(node.op, _NoPythonOp): + thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[]) + else: + thunk = node.op.make_thunk( + node, storage_map, compute_map, no_recycling=[], impl="py" + ) required = thunk() # A node whose inputs are all provided should always return successfully diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index e626b0720b..3ab0884f6b 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -98,7 +98,7 @@ from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import ( Dot, - _matrix_matrix_matmul, + _matmul, add, mul, neg, @@ -758,7 +758,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node): ignore_newtrees=False, ), "fast_run", - position=15, + position=11, ) @@ -903,12 +903,12 @@ def local_dot22_to_dot22scalar(fgraph, node): "local_dot22_to_dot22scalar", in2out(local_dot22_to_dot22scalar), "fast_run", - position=11, + position=12, ) @register_specialize -@node_rewriter([_matrix_matrix_matmul]) +@node_rewriter([_matmul]) def specialize_matmul_to_batched_dot(fgraph, node): """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot. 
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 98fc4e074c..7ceea15b04 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -28,21 +28,17 @@ from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( MakeVector, - alloc, - cast, constant, - get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import add, exp, mul from pytensor.tensor.rewriting.basic import ( alloc_like, broadcasted_by, register_canonicalize, register_specialize, + register_stabilize, ) -from pytensor.tensor.shape import shape_padleft from pytensor.tensor.variable import TensorConstant @@ -395,6 +391,7 @@ def is_dimshuffle_useless(new_order, input): @register_canonicalize +@register_stabilize @register_specialize @node_rewriter([DimShuffle]) def local_dimshuffle_lift(fgraph, node): @@ -483,66 +480,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): """ if len(node.outputs) > 1: - return - try: - shape_i = fgraph.shape_feature.shape_i - except AttributeError: - shape_i = None - if isinstance(node.op, Elemwise): - scalar_op = node.op.scalar_op - # print "aa", scalar_op.output_types_preference - if getattr(scalar_op, "output_types_preference", None) in ( - ps.upgrade_to_float, - ps.upcast_out, - ): - # this is the kind of op that we can screw with the input - # dtypes by upcasting explicitly - output_dtype = node.outputs[0].type.dtype - new_inputs = [] - for i in node.inputs: - if i.type.dtype == output_dtype: - new_inputs.append(i) - else: - try: - cval_i = get_underlying_scalar_constant_value( - i, only_process_constants=True - ) - if all(i.broadcastable): - new_inputs.append( - shape_padleft(cast(cval_i, output_dtype), i.ndim) - ) - else: - if shape_i is None: - return - new_inputs.append( - alloc( - cast(cval_i, output_dtype), - *[shape_i(d)(i) for d in range(i.ndim)], - ) - ) - # print >> sys.stderr, "AAA", - # *[Shape_i(d)(i) for d in range(i.ndim)] - except NotScalarConstantError: - # for the case of a non-scalar - if isinstance(i, TensorConstant): - new_inputs.append(cast(i, output_dtype)) - else: - new_inputs.append(i) + return None - if new_inputs != node.inputs: - rval = [node.op(*new_inputs)] - if not node.outputs[0].type.is_super(rval[0].type): - # This can happen for example when floatX=float32 - # and we do the true division between and int64 - # and a constant that will get typed as int8. + if getattr(node.op.scalar_op, "output_types_preference", None) not in ( + ps.upgrade_to_float, + ps.upcast_out, + ): + return None - # As this is just to allow merging more case, if - # the upcast don't work, we can just skip it. - return + # this is the kind of op that we can screw with the input + # dtypes by upcasting explicitly + [old_out] = node.outputs + output_dtype = old_out.type.dtype + new_inputs = list(node.inputs) + changed = False + for i, inp in enumerate(node.inputs): + if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant): + new_inputs[i] = constant(inp.data.astype(output_dtype)) + changed = True + + if not changed: + return None + + rval = node.op(*new_inputs) + if not old_out.type.is_super(rval.type): + # This can happen for example when floatX=float32 + # and we do the true division between and int64 + # and a constant that will get typed as int8. 
+ # As this is just to allow merging more case, if + # the upcast don't work, we can just skip it. + return None - # Copy over output stacktrace from before upcasting - copy_stack_trace(node.outputs[0], rval) - return rval + # Copy over output stacktrace from before upcasting + copy_stack_trace(old_out, rval) + return [rval] @node_rewriter([add, mul]) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index af42bee236..afee40bb80 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -26,7 +26,7 @@ from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod +from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod from pytensor.tensor.nlinalg import ( SVD, KroneckerProduct, @@ -282,7 +282,7 @@ def cholesky_ldotlt(fgraph, node): # This rewrite only applies to matrix Dot and A.owner.inputs[0].type.ndim == 2 ) - or (A.owner.op == _matrix_matrix_matmul) + or (A.owner.op == _matmul) ) ): return diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 9694a022e3..cc7690ce82 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -28,6 +28,7 @@ as_tensor_variable, cast, constant, + expand_dims, get_underlying_scalar_constant_value, moveaxis, ones_like, @@ -35,7 +36,6 @@ switch, zeros_like, ) -from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays @@ -44,10 +44,8 @@ Prod, Sum, _conj, - _inner_prod, - _matrix_matrix_matmul, - _matrix_vec_prod, - _vec_matrix_prod, + _dot, + _matmul, add, digamma, dot, @@ -97,6 +95,7 @@ register_useless, ) from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift +from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( @@ -174,81 +173,154 @@ def local_lift_transpose_through_dot(fgraph, node): These rewrites "lift" (propagate towards the inputs) `DimShuffle` through dot product. It allows to put the graph in a more standard shape, and to later merge consecutive `DimShuffle`\s. - - The transformation should be apply whether or not the transpose is - inplace. The newly-introduced transpositions are not inplace, this will - be taken care of in a later rewrite phase. 
- """ - if not (isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)): - return False - if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): + + if not ( + is_matrix_transpose(node.out) + and node.inputs[0].owner + and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul)) + ): return False + x, y = node.inputs[0].owner.inputs - if x.ndim == y.ndim == 2: + if x.ndim >= y.ndim >= 2: # Output is dot product of transposed inputs in reverse order - ret = [dot(y.T, x.T)] + ret = [dot_op(y.mT, x.mT)] # Copy over stack trace to output from result of dot-product copy_stack_trace(node.inputs[0], ret) return ret -@register_stabilize +@register_canonicalize @register_specialize -@node_rewriter(tracks=[Blockwise]) +@node_rewriter(tracks=[_matmul]) def local_batched_matmul_to_core_matmul(fgraph, node): - """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul. + """Move batch dimensions of matmul operands to core matmul - Example, if x has batch dimensions, but y not: + Example, if x has batch dimensions that don't overlap with batch dimensions of y x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1]) - It also works when y has batch dimensions, but x not. + It also works for batch dimensions of y that don't overlap with batch dimensions of x """ - # Check whether we have a matmul operation in this node - if not ( - isinstance(node.op.core_op, Dot) - and len(node.op.inputs_sig[0]) == 2 - and len(node.op.inputs_sig[1]) == 2 - ): - return None - x, y = node.inputs batch_ndim = node.op.batch_ndim(node) - # Check if x has batch dimensions, but y not (or only broadcastable dimensions) - if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all( - y.type.broadcastable[:-2] - ): - x_stacked = x.reshape((-1, x.shape[-1])) - out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim))) - out = out_stacked.reshape((*x.shape[:-1], y.shape[-1])) - return [out] - - # Otherwise, check if y has batch dimension, but x not - elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all( - x.type.broadcastable[:-2] - ): - # For the y batch case we need to first move the batch axes and then reshape - # y.shape == (*b, k, n) - y_tr = moveaxis(y, -2, 0) # (k, *b, n) - y_stacked = y_tr.reshape((y.shape[-2], -1)) # (k, *b * n) - out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked # (m, *b * n) - out_stacked_tr = out_stacked.reshape( - (x.shape[-2], *y.shape[:-2], y.shape[-1]) - ) # (m, *b, n) - out = moveaxis(out_stacked_tr, 0, -2) # (*b, m, n) - return [out] - - # Both x and y have batch dimensions, nothing to do here - return None + x_axis_to_merge = [ + i + for i, (bcast_x, bcast_y) in enumerate( + zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2]) + ) + if bcast_y and not bcast_x + ] + + y_axis_to_merge = [ + i + for i, (bcast_x, bcast_y) in enumerate( + zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2]) + ) + if bcast_x and not bcast_y + ] + + if not (x_axis_to_merge or y_axis_to_merge): + return None + + x_shape = tuple(x.shape) + y_shape = tuple(y.shape) + x_is_row = x.type.broadcastable[-2] + y_is_col = y.type.broadcastable[-1] + n_x_axis_to_merge = len(x_axis_to_merge) + n_y_axis_to_merge = len(y_axis_to_merge) + n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge + + x_stacked, y_stacked = x, y + dims_were_merged = False + + if n_x_axis_to_merge: + # ravel batch dimensions of x on the core (m) axis + x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2)) + x_stacked = 
moveaxis(x, x_axis_to_merge, x_axis_destination) + if x_is_row: + # x was a row matrix, squeeze it to clean up the graph + x_stacked = x_stacked.squeeze(-2) + if n_x_axis_to_merge > 1 or not x_is_row: + # Ravel moved batch dims together with (m) if needed + x_stacked_shape = tuple(x_stacked.shape) + x_stacked = x_stacked.reshape( + (*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1]) + ) + dims_were_merged = True + + if n_y_axis_to_merge: + # ravel batch dimensions of y on the core (n) axis + y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1)) + y_stacked = moveaxis(y, y_axis_to_merge, y_axis_destination) + if y_is_col: + # y was a column matrix, squeeze it to clean up the graph + y_stacked = y_stacked.squeeze(-1) + if n_y_axis_to_merge > 1 or not y_is_col: + # Ravel moved batch dims together with (n) if needed + y_stacked_shape = tuple(y_stacked.shape) + y_stacked = y_stacked.reshape( + (*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1) + ) + dims_were_merged = True + + # Squeeze x_dims corresponding to merged dimensions of y + x_axis_to_squeeze = np.array(y_axis_to_merge) + for i in reversed(x_axis_to_merge): + # The corresponding dimensions of y may have shifted when we merged dimensions of x + x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1 + x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze)) + + # Same for y + y_axis_to_squeeze = np.array(x_axis_to_merge) + for i in reversed(y_axis_to_merge): + y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1 + y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze)) + + out_stacked = x_stacked @ y_stacked + + # Split back any merged dimensions + if dims_were_merged: + x_merged_shapes = [x_shape[i] for i in x_axis_to_merge] + if not x_is_row: + # Otherwise we handle that later with expand_dims, which is cleaner + x_merged_shapes.append(x_shape[-2]) + y_merged_shapes = [y_shape[i] for i in y_axis_to_merge] + if not y_is_col: + # Otherwise we handle that later with expand_dims, which is cleaner + y_merged_shapes.append(y_shape[-1]) + out_stacked_shape = tuple(out_stacked.shape) + out_unstacked = out_stacked.reshape( + ( + *out_stacked_shape[: batch_ndim - n_axis_to_merge], + *x_merged_shapes, + *y_merged_shapes, + ) + ) + else: + out_unstacked = out_stacked + + # Add back dummy row, col axis + # We do this separately to avoid the reshape as much as we can + if y_is_col and (n_y_axis_to_merge or dims_were_merged): + out_unstacked = expand_dims(out_unstacked, -1) + if x_is_row and (n_x_axis_to_merge or dims_were_merged): + out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2) + + # Move batch axis back to their original location + source = range(-n_axis_to_merge - 2, 0) + destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1) + out = moveaxis(out_unstacked, source, destination) + return [out] @register_canonicalize @register_specialize -@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul]) +@node_rewriter([_matmul]) def local_blockwise_dot_to_mul(fgraph, node): """Rewrite blockwise dots that correspond to multiplication without summation. 
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index be16c4fb61..8e9fba22e4 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -351,7 +351,8 @@ def local_useless_slice(fgraph, node): new_idxs[dim] = slice(start, stop, step) if change_flag or ((last_useful_idx + 1) < len(idxs)): - out = x[tuple(new_idxs[: last_useful_idx + 1])] + new_idxs = tuple(new_idxs[: last_useful_idx + 1]) + out = x[new_idxs] if new_idxs else x # Copy over previous output stacktrace copy_stack_trace(node.outputs, out) return [out] @@ -369,74 +370,73 @@ def local_subtensor_merge(fgraph, node): """ from pytensor.scan.op import Scan - if isinstance(node.op, Subtensor): - u = node.inputs[0] - if u.owner and isinstance(u.owner.op, Subtensor): - # We can merge :) - # x actual tensor on which we are picking slices - x = u.owner.inputs[0] - # slices of the first applied subtensor - slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) - slices2 = get_idx_list(node.inputs, node.op.idx_list) - - # Don't try to do the optimization on do-while scan outputs, - # as it will create a dependency on the shape of the outputs - if ( - x.owner is not None - and isinstance(x.owner.op, Scan) - and x.owner.op.info.as_while - ): - return None + u = node.inputs[0] + if not (u.owner is not None and isinstance(u.owner.op, Subtensor)): + return None - # Get the shapes of the vectors ! - try: - # try not to introduce new shape into the graph - xshape = fgraph.shape_feature.shape_of[x] - ushape = fgraph.shape_feature.shape_of[u] - except AttributeError: - # Following the suggested use of shape_feature which should - # consider the case when the compilation mode doesn't - # include the ShapeFeature - xshape = x.shape - ushape = u.shape - - merged_slices = [] - pos_2 = 0 - pos_1 = 0 - while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): - slice1 = slices1[pos_1] - if isinstance(slice1, slice): - merged_slices.append( - merge_two_slices( - fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] - ) - ) - pos_2 += 1 - else: - merged_slices.append(slice1) - pos_1 += 1 - - if pos_2 < len(slices2): - merged_slices += slices2[pos_2:] - else: - merged_slices += slices1[pos_1:] + # We can merge :) + # x actual tensor on which we are picking slices + x = u.owner.inputs[0] + # slices of the first applied subtensor + slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) + slices2 = get_idx_list(node.inputs, node.op.idx_list) - merged_slices = tuple(as_index_constant(s) for s in merged_slices) - subtens = Subtensor(merged_slices) + # Don't try to do the optimization on do-while scan outputs, + # as it will create a dependency on the shape of the outputs + if ( + x.owner is not None + and isinstance(x.owner.op, Scan) + and x.owner.op.info.as_while + ): + return None - sl_ins = get_slice_elements( - merged_slices, lambda x: isinstance(x, Variable) + # Get the shapes of the vectors ! 
+ try: + # try not to introduce new shape into the graph + xshape = fgraph.shape_feature.shape_of[x] + ushape = fgraph.shape_feature.shape_of[u] + except AttributeError: + # Following the suggested use of shape_feature which should + # consider the case when the compilation mode doesn't + # include the ShapeFeature + xshape = x.shape + ushape = u.shape + + merged_slices = [] + pos_2 = 0 + pos_1 = 0 + while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): + slice1 = slices1[pos_1] + if isinstance(slice1, slice): + merged_slices.append( + merge_two_slices( + fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] + ) ) - # Do not call make_node for test_value - out = subtens(x, *sl_ins) + pos_2 += 1 + else: + merged_slices.append(slice1) + pos_1 += 1 - # Copy over previous output stacktrace - # and stacktrace from previous slicing operation. - # Why? Because, the merged slicing operation could have failed - # because of either of the two original slicing operations - orig_out = node.outputs[0] - copy_stack_trace([orig_out, node.inputs[0]], out) - return [out] + if pos_2 < len(slices2): + merged_slices += slices2[pos_2:] + else: + merged_slices += slices1[pos_1:] + + merged_slices = tuple(as_index_constant(s) for s in merged_slices) + subtens = Subtensor(merged_slices) + + sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable)) + # Do not call make_node for test_value + out = subtens(x, *sl_ins) + + # Copy over previous output stacktrace + # and stacktrace from previous slicing operation. + # Why? Because, the merged slicing operation could have failed + # because of either of the two original slicing operations + orig_out = node.outputs[0] + copy_stack_trace([orig_out, node.inputs[0]], out) + return [out] @register_specialize @@ -787,6 +787,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2): if not isinstance(slice1, slice): raise ValueError("slice1 should be of type `slice`") + # Simple case where one of the slices is useless + if is_full_slice(slice1): + return slice2 + elif is_full_slice(slice2): + return slice1 + sl1, reverse1 = get_canonical_form_slice(slice1, len1) sl2, reverse2 = get_canonical_form_slice(slice2, len2) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index eb31514463..5ca7fa5929 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -5,7 +5,7 @@ from pytensor import Variable from pytensor.compile import optdb -from pytensor.graph import Constant, FunctionGraph, node_rewriter +from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.scalar import basic as ps @@ -20,6 +20,7 @@ join, register_infer_shape, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import squeeze @@ -118,21 +119,43 @@ def local_subtensor_of_dot(fgraph, node): the remaining entries of ``idxs`` (if any), modified to skip the second-to-last dimension of ``B`` (because dot sums over this dimension). 
""" - if not isinstance(node.op, Subtensor): - return - if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): + x, *idx_vars = node.inputs + if not ( + x.owner is not None + and ( + isinstance(x.owner.op, Dot) + or ( + isinstance(x.owner.op, Blockwise) + and isinstance(x.owner.op.core_op, Dot) + ) + ) + ): return # If there is other node that use the outputs of the dot # We don't want to compute twice the sub part. - if len(fgraph.clients[node.inputs[0]]) > 1: + if len(fgraph.clients[x]) > 1: return - a = node.inputs[0].owner.inputs[0] - b = node.inputs[0].owner.inputs[1] + a = x.owner.inputs[0] + b = x.owner.inputs[1] + idx_list = indices_from_subtensor(idx_vars, node.op.idx_list) - idx_list = get_idx_list(node.inputs, node.op.idx_list) + batch_ndim = ( + x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0 + ) - num_a_indices = min(a.ndim - 1, len(idx_list)) + if batch_ndim: + batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:] + if not idx_list: + # Indexing only over batch dimensions of Blockwise, that can be handled by another rewrite + return None + # We perform the rest of the rewrite on dummy a, b that correspond to the core case + a = a.type.clone(shape=a.type.shape[batch_ndim:])() + b = b.type.clone(shape=b.type.shape[batch_ndim:])() + + a_ndim = a.ndim + b_ndim = b.ndim + num_a_indices = min(a_ndim - 1, len(idx_list)) a_indices = idx_list[:num_a_indices] b_indices = idx_list[num_a_indices:] @@ -141,26 +164,22 @@ def local_subtensor_of_dot(fgraph, node): # This wasn't necessary for a, because we just omitted the last index. # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] # (dot also handles b.ndim < 2 as a special case) - if b.ndim > 1 and len(b_indices) >= b.ndim - 1: + if b_ndim > 1 and len(b_indices) >= b_ndim - 1: b_indices = ( - b_indices[: b.ndim - 2] + b_indices[: b_ndim - 2] + (slice(None, None, None),) - + b_indices[b.ndim - 2 :] + + b_indices[b_ndim - 2 :] ) - a_sub = a.__getitem__(tuple(a_indices)) - b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b + a_sub = a[tuple(a_indices)] + b_sub = b[tuple(b_indices)] if b_indices else b + r = dot(a_sub, b_sub) - # Copy over previous output stacktrace to a_sub and b_sub, - # because an error in the subtensor operation (e.g. an index error) - # on either a or b must correspond to an error in the - # subtensor operation on their dot product. - copy_stack_trace(node.outputs[0], [a_sub, b_sub]) + if batch_ndim: + # Replace dummy inputs by the original batch ones + r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]}) + r = r[tuple(batch_idx_list)] - # Copy over previous output stacktrace and previous dot product stacktrace, - # because an error here may correspond to an either in either the original - # dot product, or in the dot product after the subtensor operation. - r = dot(a_sub, b_sub) copy_stack_trace([node.outputs[0], node.inputs[0]], r) return [r] @@ -169,8 +188,8 @@ def local_subtensor_of_dot(fgraph, node): @register_canonicalize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([Subtensor]) -def local_subtensor_of_elemwise(fgraph, node): - """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior. +def local_subtensor_of_batch_dims(fgraph, node): + """Lift a Subtensor through the batch dims of an (Elemwise or Blockwise) operation and its implicit broadcasting behavior. 
exp(x)[:, 0] -> exp(x[:, 0]) add(x, y)[0] -> add(x[0], y[0]) @@ -178,7 +197,7 @@ def local_subtensor_of_elemwise(fgraph, node): """ elem, *idx = node.inputs - if not (elem.owner and isinstance(elem.owner.op, Elemwise)): + if not (elem.owner and isinstance(elem.owner.op, Elemwise | Blockwise)): return None if len(fgraph.clients[elem]) > 1: @@ -188,9 +207,34 @@ def local_subtensor_of_elemwise(fgraph, node): idx_tuple = indices_from_subtensor(idx, node.op.idx_list) + batch_ndim = ( + elem.owner.op.batch_ndim(elem.owner) + if isinstance(elem.owner.op, Blockwise) + else elem.ndim + ) + + if len(idx_tuple) > batch_ndim: + # Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only + batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:] + if all(is_full_slice(idx) for idx in batch_indices): + # No batch indices, nothing to do + return None + elem_with_batch_indices = elem[batch_indices] + [elem_with_batch_indices_lifted] = local_subtensor_of_batch_dims.transform( + fgraph, elem_with_batch_indices.owner + ) + # Reapply the core_indices + core_ndim = elem.type.ndim - batch_ndim + # Number of batch dims may have changed with the lifting of indices, so we recompute + new_batch_ndim = elem_with_batch_indices_lifted.type.ndim - core_ndim + new_indices = (*(slice(None),) * new_batch_ndim, *core_indices) + new_elem = elem_with_batch_indices_lifted[new_indices] + copy_stack_trace(node.outputs[0], new_elem) + return [new_elem] + elem_inputs = elem.owner.inputs - elem_bcast = elem.type.broadcastable - if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs): + elem_bcast = elem.type.broadcastable[:batch_ndim] + if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs): # No need to worry about implicit broadcasting. 
indexed_inputs = [inp[idx_tuple] for inp in elem_inputs] @@ -201,7 +245,7 @@ def local_subtensor_of_elemwise(fgraph, node): zip( idx_tuple, elem_bcast, - *(inp.type.broadcastable for inp in elem_inputs), + *(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs), # Indices can be shorter than input ndims strict=False, ) @@ -435,6 +479,41 @@ def local_subtensor_of_expand_dims(fgraph, node): return [out] +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_squeeze(fgraph, node): + """Lift subtensor through a squeeze operation""" + x, *idxs_vars = node.inputs + if not ( + x.owner is not None + and isinstance(x.owner.op, DimShuffle) + and x.owner.op.is_squeeze + ): + return None + + [x_before_squeeze] = x.owner.inputs + idxs = indices_from_subtensor(idxs_vars, node.op.idx_list) + dropped_dims = x.owner.op.drop + + # Apply indices directly on x + # Add empty slices on the axis that squeeze would have removed + new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None)) + x_indexed = x_before_squeeze[tuple(new_idxs)] + + # Reapply squeeze + # Indexing may have squeezed some dimensions, so we need to recalculate dropped_dims + new_dropped_dims = np.array(dropped_dims) + for i, new_idx in reversed(tuple(enumerate(new_idxs))): + if not isinstance(new_idx, slice): + # If it's not a slice, it's an integer which drops the dimension + new_dropped_dims[new_dropped_dims > i] -= 1 + new_x = x_indexed.squeeze(tuple(new_dropped_dims)) + + copy_stack_trace(x, new_x) + return [new_x] + + @register_canonicalize @register_specialize @node_rewriter([Subtensor]) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 9a092663a9..1e5dc5ba97 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -42,6 +42,7 @@ Prod, Sum, _conj, + _matmul, add, arccosh, arcsinh, @@ -4562,6 +4563,80 @@ def test_local_batched_matmul_to_core_matmul(): np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test) +@pytest.mark.parametrize( + "mat_shape, vec_shape", + [ + [(1, 2, 2), (5, 2)], + [(5, 2, 2), (1, 2)], + [(1, 1, 2, 2), (7, 5, 2)], + [(7, 5, 2, 2), (1, 1, 5, 2)], + [(1, 5, 1, 2, 2), (7, 5, 7, 2)], + [(7, 5, 7, 2, 2), (1, 5, 1, 2)], + [(5, 1, 3, 1, 2, 2), (1, 7, 3, 7, 2)], + [(1, 7, 3, 7, 2, 2), (5, 1, 3, 1, 2)], + ], + ids=str, +) +@pytest.mark.parametrize("func", ("matvec", "vecmat", "vecdot")) +def test_batch_matvec_to_matmul(func, mat_shape, vec_shape): + def count_matvec_nodes(graph): + # Counts how many matmul nodes actually correspond to matvec or vecmat + return len( + [ + var + for var in ancestors([graph]) + if ( + var.owner is not None + and var.owner.op == _matmul + and ( + (var.owner.inputs[0].type.shape[-2] == 1) + or (var.owner.inputs[1].type.shape[-1] == 1) + ) + ) + ] + ) + + mat = pt.tensor("mat", shape=mat_shape, dtype="float64") + vec = pt.tensor("vec", shape=vec_shape, dtype="float64") + + if func == "matvec": + out = pt.matvec(mat, vec) + elif func == "vecmat": + out = pt.vecmat(vec, mat) + elif func == "vecdot": + out = pt.vecdot(mat[..., 0], vec) + else: + raise NotImplementedError(func) + + assert count_matvec_nodes(out) == 1 + + rewritten_out = rewrite_graph( + out, exclude=("local_eager_useless_unbatched_blockwise",) + ) + # No `matvec` in the rewritten out if one of the vector can be treated as a matrix + expected = not any( + mat_dim == 1 and vec_dim != 1 + for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2]) + ) + if not expected and func == 
"vecdot": + # In this case there are two vectors, so we may still end up with a `matvec` unless the second vec can also be treated as matrix + expected = not any( + mat_dim != 1 and vec_dim == 1 + for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2]) + ) + + assert count_matvec_nodes(rewritten_out) == expected + + rng = np.random.default_rng(mat_shape + vec_shape) + eval_dict = {mat: rng.random(mat.type.shape), vec: rng.random(vec.type.shape)} + # Evaluate results are correct without further rewrites + no_optimization = Mode(linker="py", optimizer=None) + np.testing.assert_allclose( + rewritten_out.eval(eval_dict, mode=no_optimization), + out.eval(eval_dict, mode=no_optimization), + ) + + def test_log_kv_stabilization(): x = pt.scalar("x") out = log(kv(4.5, x)) diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 933d1a1577..0d79afd367 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -37,14 +37,16 @@ vector, ) from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.rewriting.subtensor_lift import ( local_subtensor_make_vector, - local_subtensor_of_elemwise, + local_subtensor_of_batch_dims, local_subtensor_shape_constant, ) from pytensor.tensor.shape import SpecifyShape, _shape +from pytensor.tensor.signal import convolve1d from pytensor.tensor.special import softmax from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor @@ -58,7 +60,7 @@ NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None) -class TestLocalSubtensorOfElemwise: +class TestLocalSubtensorOfBatchDims: def test_unary_multiple_clients(self): # as test0, but we reuse the output of the elemwise # So we should not lift the subtensor @@ -144,7 +146,7 @@ def test_multinary_multiple_clients(self): ), ], ) - def test_local_subtensor_of_elemwise(self, original_fn, expected_fn): + def test_elemwise(self, original_fn, expected_fn): rng = np.random.default_rng(257) x = pt.matrix("x", shape=(5, 3)) y = pt.matrix("y", shape=(5, 3)) @@ -163,7 +165,7 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn): out.eval({x: x_test, y: y_test}, **eval_kwargs), ) - def test_local_subtensor_of_elemwise_multiple_clients(self): + def test_elemwise_multiple_clients(self): x = pt.matrix("x", shape=(5, 3)) y = pt.matrix("y", shape=(5, 3)) out1 = add(x, y) @@ -171,11 +173,42 @@ def test_local_subtensor_of_elemwise_multiple_clients(self): # Rewrite should fail when another node uses out1 directly (in this case it's an extra output) fgraph = FunctionGraph([x, y], [out1, out2], clone=False) - assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None + assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None # Otherwise it should work fgraph.remove_output(0) - assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None + assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None + + def test_blockwise(self): + x = tensor3("x", shape=(7, 5, 11)) + y = tensor("y", shape=(7, 33)) + out = convolve1d(x, y[:, None, :]) + assert isinstance(out.owner.op, Blockwise) + + out_sliced = out[2:][:, 3:] + rewritten_out_sliced = rewrite_graph(out_sliced) + assert equal_computations( + [rewritten_out_sliced], [convolve1d(x[2:, 3:], y[2:][:, None, :])] 
+ ) + + rng = np.random.default_rng(191) + x_test = rng.normal(size=x.type.shape).astype(x.type.dtype) + y_test = rng.normal(size=y.type.shape).astype(y.type.dtype) + np.testing.assert_allclose( + rewritten_out_sliced.eval( + {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE + ), + out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE), + ) + + # Check slice on core dims + # Note: if we implement a rewrite on the core dims, this test should be changed for another Blockwise + # that has no such rewrite or one created just for testing purposes + out_sliced = out[2:][:, 0][:, 4:] + rewritten_out_sliced = rewrite_graph(out_sliced) + assert equal_computations( + [rewritten_out_sliced], [convolve1d(x[2:, 0], y[2:])[:, 4:]] + ) @pytest.mark.parametrize( diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 5e6271e170..b01a50e2fa 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3867,35 +3867,22 @@ class TestInferShape(utt.InferShapeTester): def test_Flatten(self): atens3 = tensor3() atens3_val = random(4, 5, 3) - for ndim in (3, 2, 1): + for ndim in (2, 1): self._compile_and_check( [atens3], [flatten(atens3, ndim)], [atens3_val], Reshape, - excluding=["local_useless_reshape"], ) amat = matrix() amat_val = random(4, 5) - for ndim in (2, 1): - self._compile_and_check( - [amat], - [flatten(amat, ndim)], - [amat_val], - Reshape, - excluding=["local_useless_reshape"], - ) - - avec = vector() - avec_val = random(4) ndim = 1 self._compile_and_check( - [avec], - [flatten(avec, ndim)], - [avec_val], + [amat], + [flatten(amat, ndim)], + [amat_val], Reshape, - excluding=["local_useless_reshape"], ) def test_Eye(self): diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 9673f8338e..8de9c24b18 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -4,6 +4,7 @@ import pytensor import pytensor.tensor.basic as ptb +from pytensor import function from pytensor.configdefaults import config from pytensor.gradient import ( DisconnectedInputError, @@ -31,7 +32,7 @@ from pytensor.graph.null_type import NullType from pytensor.graph.op import Op from pytensor.scan.op import Scan -from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh +from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.random import RandomStream from pytensor.tensor.type import ( @@ -940,139 +941,225 @@ def test_undefined_grad_opt(): ) -def test_jacobian_vector(): - x = vector() - y = x * 2 - rng = np.random.default_rng(seed=utt.fetch_seed()) +@pytest.mark.parametrize("vectorize", [False, True], ids=lambda x: f"vectorize={x}") +class TestJacobian: + def test_jacobian_vector(self, vectorize): + x = vector() + y = x * 2 + rng = np.random.default_rng(seed=utt.fetch_seed()) + + # test when the jacobian is called with a tensor as wrt + Jx = jacobian(y, x, vectorize=vectorize) + f = function([x], Jx) + vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), np.eye(10) * 2) + + # test when the jacobian is called with a tuple as wrt + Jx = jacobian(y, (x,), vectorize=vectorize) + assert isinstance(Jx, tuple) + f = function([x], Jx[0]) + vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), np.eye(10) * 2) + + # test when the jacobian is called with a list as wrt + Jx = jacobian(y, [x], vectorize=vectorize) + assert isinstance(Jx, list) + f = function([x], Jx[0]) + vx = 
rng.uniform(size=(10,)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), np.eye(10) * 2) + + # test when the jacobian is called with a list of two elements + z = vector() + y = x * z + Js = jacobian(y, [x, z], vectorize=vectorize) + f = function([x, z], Js) + vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + vz = rng.uniform(size=(10,)).astype(pytensor.config.floatX) + vJs = f(vx, vz) + evx = np.zeros((10, 10)) + evz = np.zeros((10, 10)) + np.fill_diagonal(evx, vx) + np.fill_diagonal(evz, vz) + assert np.allclose(vJs[0], evz) + assert np.allclose(vJs[1], evx) + + def test_jacobian_matrix(self, vectorize): + x = matrix() + y = 2 * x.sum(axis=0) + rng = np.random.default_rng(seed=utt.fetch_seed()) + ev = np.zeros((10, 10, 10)) + for dx in range(10): + ev[dx, :, dx] = 2.0 + + # test when the jacobian is called with a tensor as wrt + Jx = jacobian(y, x, vectorize=vectorize) + f = function([x], Jx) + vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), ev) + + # test when the jacobian is called with a tuple as wrt + Jx = jacobian(y, (x,), vectorize=vectorize) + assert isinstance(Jx, tuple) + f = function([x], Jx[0]) + vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), ev) + + # test when the jacobian is called with a list as wrt + Jx = jacobian(y, [x], vectorize=vectorize) + assert isinstance(Jx, list) + f = function([x], Jx[0]) + vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + assert np.allclose(f(vx), ev) + + # test when the jacobian is called with a list of two elements + z = matrix() + y = (x * z).sum(axis=1) + Js = jacobian(y, [x, z], vectorize=vectorize) + f = function([x, z], Js) + vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + vz = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) + vJs = f(vx, vz) + evx = np.zeros((10, 10, 10)) + evz = np.zeros((10, 10, 10)) + for dx in range(10): + evx[dx, dx, :] = vx[dx, :] + evz[dx, dx, :] = vz[dx, :] + assert np.allclose(vJs[0], evz) + assert np.allclose(vJs[1], evx) + + def test_jacobian_scalar(self, vectorize): + x = scalar() + y = x * 2 + rng = np.random.default_rng(seed=utt.fetch_seed()) + + # test when the jacobian is called with a tensor as wrt + Jx = jacobian(y, x, vectorize=vectorize) + f = function([x], Jx) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + assert np.allclose(f(vx), 2) + + # test when input is a shape (1,) vector -- should still be treated as a scalar + Jx = jacobian(y[None], x) + f = function([x], Jx) + + # Ensure we hit the scalar grad case (doesn't use scan) + nodes = f.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Scan) for node in nodes) + + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + assert np.allclose(f(vx), 2) + + # test when the jacobian is called with a tuple as wrt + Jx = jacobian(y, (x,), vectorize=vectorize) + assert isinstance(Jx, tuple) + f = function([x], Jx[0]) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + assert np.allclose(f(vx), 2) + + # test when the jacobian is called with a list as wrt + Jx = jacobian(y, [x], vectorize=vectorize) + assert isinstance(Jx, list) + f = function([x], Jx[0]) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + assert np.allclose(f(vx), 2) + + # test when the jacobian is called with a list of two elements + z = scalar() + y = x * z + Jx = jacobian(y, [x, z], vectorize=vectorize) + f = function([x, z], Jx) + vx = np.asarray(rng.uniform(), 
dtype=pytensor.config.floatX) + vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) + vJx = f(vx, vz) + + assert np.allclose(vJx[0], vz) + assert np.allclose(vJx[1], vx) + + @pytest.mark.parametrize("square_jac", [False, True]) + def test_jacobian_matrix_expression(self, vectorize, square_jac): + x = vector("x", shape=(3,)) + y = outer(x, x) + if not square_jac: + y = y[:, 1:] + Jy_wrt_x = jacobian(y, wrt=x, vectorize=vectorize) + f = function([x], Jy_wrt_x) + x_test = np.arange(3, dtype=x.type.dtype) + res = f(x_test) + expected_res = np.array( + [ + # Jy[0]_wrt_x (y[0] = x[0] * x) + [[0, 0, 0], [1, 0, 0], [2, 0, 0]], + # Jy[1]_wrt_x (y[1] = x[1] * x) + [ + [1, 0, 0], + [0, 2, 0], + [0, 2, 1], + ], + # Jy[2]_wrt_x (y[2] = x[2] * x) + [ + [2, 0, 0], + [0, 2, 1], + [0, 0, 4], + ], + ] + ) + if not square_jac: + expected_res = expected_res[:, 1:, :] + np.testing.assert_allclose(res, expected_res) + + def test_jacobian_disconnected_inputs(self, vectorize): + # Test that disconnected inputs are properly handled by jacobian. + s1 = scalar("s1") + s2 = scalar("s2") + jacobian_s = jacobian(1 + s1, s2, disconnected_inputs="ignore") + func_s = function([s2], jacobian_s) + val = np.array(1.0, dtype=config.floatX) + np.testing.assert_allclose(func_s(val), np.zeros(1)) + + v1 = vector("v1") + v2 = vector("v2") + jacobian_v = jacobian( + 1 + v1, v2, disconnected_inputs="ignore", vectorize=vectorize + ) + func_v = function([v1, v2], jacobian_v, on_unused_input="ignore") + val = np.arange(4.0, dtype=pytensor.config.floatX) + np.testing.assert_allclose(func_v(val, val), np.zeros((4, 4))) + + m1 = matrix("m1") + m2 = matrix("m2") + jacobian_m = jacobian( + 1 + m1[1:, 2:], m2, disconnected_inputs="ignore", vectorize=vectorize + ) + func_v = function([m1, m2], jacobian_m, on_unused_input="ignore") + val = np.ones((4, 4), dtype=config.floatX) + np.testing.assert_allclose(func_v(val, val), np.zeros((3, 2, 4, 4))) - # test when the jacobian is called with a tensor as wrt - Jx = jacobian(y, x) - f = pytensor.function([x], Jx) - vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), np.eye(10) * 2) + def test_benchmark(self, vectorize, benchmark): + x = vector("x", shape=(3,)) + y = outer(x, x) - # test when the jacobian is called with a tuple as wrt - Jx = jacobian(y, (x,)) - assert isinstance(Jx, tuple) - f = pytensor.function([x], Jx[0]) - vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), np.eye(10) * 2) + jac_y = jacobian(y, x, vectorize=vectorize) - # test when the jacobian is called with a list as wrt - Jx = jacobian(y, [x]) - assert isinstance(Jx, list) - f = pytensor.function([x], Jx[0]) - vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), np.eye(10) * 2) + fn = function([x], jac_y, trust_input=True) + benchmark(fn, np.array([0, 1, 2], dtype=x.type.dtype)) - # test when the jacobian is called with a list of two elements - z = vector() - y = x * z - Js = jacobian(y, [x, z]) - f = pytensor.function([x, z], Js) - vx = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - vz = rng.uniform(size=(10,)).astype(pytensor.config.floatX) - vJs = f(vx, vz) - evx = np.zeros((10, 10)) - evz = np.zeros((10, 10)) - np.fill_diagonal(evx, vx) - np.fill_diagonal(evz, vz) - assert np.allclose(vJs[0], evz) - assert np.allclose(vJs[1], evx) - - -def test_jacobian_matrix(): - x = matrix() - y = 2 * x.sum(axis=0) - rng = np.random.default_rng(seed=utt.fetch_seed()) - ev = np.zeros((10, 10, 10)) - for dx in 
range(10): - ev[dx, :, dx] = 2.0 - - # test when the jacobian is called with a tensor as wrt - Jx = jacobian(y, x) - f = pytensor.function([x], Jx) - vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), ev) - - # test when the jacobian is called with a tuple as wrt - Jx = jacobian(y, (x,)) - assert isinstance(Jx, tuple) - f = pytensor.function([x], Jx[0]) - vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), ev) - - # test when the jacobian is called with a list as wrt - Jx = jacobian(y, [x]) - assert isinstance(Jx, list) - f = pytensor.function([x], Jx[0]) - vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - assert np.allclose(f(vx), ev) - - # test when the jacobian is called with a list of two elements - z = matrix() - y = (x * z).sum(axis=1) - Js = jacobian(y, [x, z]) - f = pytensor.function([x, z], Js) - vx = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - vz = rng.uniform(size=(10, 10)).astype(pytensor.config.floatX) - vJs = f(vx, vz) - evx = np.zeros((10, 10, 10)) - evz = np.zeros((10, 10, 10)) - for dx in range(10): - evx[dx, dx, :] = vx[dx, :] - evz[dx, dx, :] = vz[dx, :] - assert np.allclose(vJs[0], evz) - assert np.allclose(vJs[1], evx) - - -def test_jacobian_scalar(): - x = scalar() - y = x * 2 - rng = np.random.default_rng(seed=utt.fetch_seed()) - - # test when the jacobian is called with a tensor as wrt - Jx = jacobian(y, x) - f = pytensor.function([x], Jx) - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - assert np.allclose(f(vx), 2) - - # test when input is a shape (1,) vector -- should still be treated as a scalar - Jx = jacobian(y[None], x) - f = pytensor.function([x], Jx) - - # Ensure we hit the scalar grad case (doesn't use scan) - nodes = f.maker.fgraph.apply_nodes - assert not any(isinstance(node.op, Scan) for node in nodes) - - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - assert np.allclose(f(vx), 2) - - # test when the jacobian is called with a tuple as wrt - Jx = jacobian(y, (x,)) - assert isinstance(Jx, tuple) - f = pytensor.function([x], Jx[0]) - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - assert np.allclose(f(vx), 2) - - # test when the jacobian is called with a list as wrt - Jx = jacobian(y, [x]) - assert isinstance(Jx, list) - f = pytensor.function([x], Jx[0]) - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - assert np.allclose(f(vx), 2) - - # test when the jacobian is called with a list of two elements - z = scalar() - y = x * z - Jx = jacobian(y, [x, z]) - f = pytensor.function([x, z], Jx) - vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) - vJx = f(vx, vz) - - assert np.allclose(vJx[0], vz) - assert np.allclose(vJx[1], vx) + def test_benchmark_partial_jacobian(self, vectorize, benchmark): + # Example from https://github.com/jax-ml/jax/discussions/5904#discussioncomment-422956 + N = 1000 + rng = np.random.default_rng(2025) + x_test = rng.random((N,)) + + f_mat = rng.random((N, N)) + x = vector("x", dtype="float64") + + def f(x): + return sqrt(f_mat @ x / N) + + full_jacobian = jacobian(f(x), x, vectorize=vectorize) + partial_jacobian = full_jacobian[:5, :5] + + f = pytensor.function([x], partial_jacobian, trust_input=True) + benchmark(f, x_test) def test_hessian(): @@ -1084,25 +1171,7 @@ def test_hessian(): assert np.allclose(f(vx), np.eye(10) * 2) -def test_jacobian_disconnected_inputs(): - # Test that disconnected 
inputs are properly handled by jacobian. - - v1 = vector() - v2 = vector() - jacobian_v = pytensor.gradient.jacobian(1 + v1, v2, disconnected_inputs="ignore") - func_v = pytensor.function([v1, v2], jacobian_v) - val = np.arange(4.0).astype(pytensor.config.floatX) - assert np.allclose(func_v(val, val), np.zeros((4, 4))) - - s1 = scalar() - s2 = scalar() - jacobian_s = pytensor.gradient.jacobian(1 + s1, s2, disconnected_inputs="ignore") - func_s = pytensor.function([s2], jacobian_s) - val = np.array(1.0).astype(pytensor.config.floatX) - assert np.allclose(func_s(val), np.zeros(1)) - - -class TestHessianVectorProdudoct: +class TestHessianVectorProduct: def test_rosen(self): x = vector("x", dtype="float64") rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()