diff --git a/pyproject.toml b/pyproject.toml index 9dd60e12dd..2e892535c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,7 +163,7 @@ lines-after-imports = 2 "tests/link/numba/**/test_*.py" = ["E402"] "tests/link/pytorch/**/test_*.py" = ["E402"] "tests/link/mlx/**/test_*.py" = ["E402"] -"tests/xtensor/**/test_*.py" = ["E402"] +"tests/xtensor/**/*.py" = ["E402"] diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index bb49245ebe..ea19cc2882 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -208,19 +208,13 @@ def toposort_key( @singledispatch -def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply: +def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]: # Default implementation is provided in pytensor.tensor.blockwise raise NotImplementedError -def vectorize_node(node: Apply, *batched_inputs) -> Apply: - """Returns vectorized version of node with new batched inputs.""" - op = node.op - return _vectorize_node(op, node, *batched_inputs) - - def _vectorize_not_needed(op, node, *batched_inputs): - return op.make_node(*batched_inputs) + return op.make_node(*batched_inputs).outputs @overload @@ -289,19 +283,37 @@ def vectorize_graph( # [array([-10., -11.]), array([10., 11.])] """ + # TODO: Move this to tensor.vectorize, and make this helper type agnostic. + # + # This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types + # The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels + # + # xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't, + # as it is by design unaware of xtensors and their semantics. if isinstance(outputs, Sequence): seq_outputs = outputs else: seq_outputs = [outputs] + if not all( + isinstance(key, Variable) and isinstance(value, Variable) + for key, value in replace.items() + ): + raise ValueError(f"Some of the replaced items are not Variables: {replace}") + inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys()) new_inputs = [replace.get(inp, inp) for inp in inputs] vect_vars = dict(zip(inputs, new_inputs, strict=True)) for node in toposort(seq_outputs, blockers=inputs): vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs] - vect_node = vectorize_node(node, *vect_inputs) - for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True): + vect_node_or_outputs = _vectorize_node(node.op, node, *vect_inputs) + if isinstance(vect_node_or_outputs, Apply): + # Compatibility with the old API + vect_outputs = vect_node_or_outputs.outputs + else: + vect_outputs = vect_node_or_outputs + for output, vect_output in zip(node.outputs, vect_outputs, strict=True): if output in vect_vars: # This can happen when some outputs of a multi-output node are given a replacement, # while some of the remaining outputs are still needed in the graph. 
diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 650d8dc54c..2400750dab 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -1,7 +1,7 @@ from pytensor.compile.mode import optdb from pytensor.graph import Constant, Op, node_rewriter from pytensor.graph.destroyhandler import inplace_candidates -from pytensor.graph.replace import vectorize_node +from pytensor.graph.replace import _vectorize_node from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft @@ -38,7 +38,7 @@ def local_useless_blockwise(fgraph, node): op = node.op inputs = node.inputs dummy_core_node = op._create_dummy_core_node(node.inputs) - vect_node = vectorize_node(dummy_core_node, *inputs) + vect_node = _vectorize_node(dummy_core_node.op, dummy_core_node, *inputs) if not isinstance(vect_node.op, Blockwise): return copy_stack_trace(node.outputs, vect_node.outputs) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index e7fcdbdf3a..143d74d52c 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -1602,7 +1602,7 @@ def local_blockwise_of_subtensor(fgraph, node): def local_blockwise_inc_subtensor(fgraph, node): """Rewrite blockwised inc_subtensors. - Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch + Note: The reason we don't apply this rewrite eagerly in the `_vectorize_node` dispatch Is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y), diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 91f81f3dca..edde022ced 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -17,6 +17,9 @@ def perform(self, node, inputs, outputs): def do_constant_folding(self, fgraph, node): return False + def vectorize_node(self, node, *new_inputs, new_dim: str | None): + raise NotImplementedError(f"Vectorized node not implemented for {self}") + class XTypeCastOp(TypeCastingOp): """Base class for Ops that type cast between TensorType and XTensorType. @@ -27,6 +30,9 @@ class XTypeCastOp(TypeCastingOp): def infer_shape(self, fgraph, node, input_shapes): return input_shapes + def vectorize_node(self, node, *new_inputs, new_dim: str | None): + raise NotImplementedError(f"Vectorized node not implemented for {self}") + class TensorFromXTensor(XTypeCastOp): __props__ = () @@ -42,6 +48,12 @@ def L_op(self, inputs, outs, g_outs): [g_out] = g_outs return [xtensor_from_tensor(g_out, dims=x.type.dims)] + def vectorize_node(self, node, new_x, new_dim): + [old_x] = node.inputs + # We transpose batch dims to the left, for consistency with tensor vectorization + new_x = new_x.transpose(..., *old_x.dims) + return [self(new_x)] + tensor_from_xtensor = TensorFromXTensor() @@ -63,6 +75,18 @@ def L_op(self, inputs, outs, g_outs): [g_out] = g_outs return [tensor_from_xtensor(g_out)] + def vectorize_node(self, node, new_x, new_dim): + [old_x] = node.inputs + if new_x.ndim != old_x.ndim: + if new_dim is None: + raise NotImplementedError( + f"Vectorization of {self} is not well defined because it can't infer the new dimension labels. " + f"Use pytensor.xtensor.vectorization.vectorize_graph instead." 
+ ) + return [type(self)(dims=(new_dim, *self.dims))(new_x)] + else: + return [self(new_x)] + def xtensor_from_tensor(x, dims, name=None): return XTensorFromTensor(dims=dims)(x, name=name) @@ -85,6 +109,16 @@ def L_op(self, inputs, outs, g_outs): [g_out] = g_outs return [rename(g_out, dims=x.type.dims)] + def vectorize_node(self, node, new_x, new_dim): + [old_x] = node.inputs + old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True)) + + # new_dims may include a mix of old dims (possibly re-ordered), and new dims which won't be renamed + new_dims = tuple( + old_dim_mapping.get(new_dim, new_dim) for new_dim in new_x.dims + ) + return [type(self)(new_dims)(new_x)] + def rename(x, name_dict: dict[str, str] | None = None, **names: str): if name_dict is not None: diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py index 427a543d4c..ca3591ba7f 100644 --- a/pytensor/xtensor/indexing.py +++ b/pytensor/xtensor/indexing.py @@ -4,6 +4,7 @@ # https://numpy.org/neps/nep-0021-advanced-indexing.html # https://docs.xarray.dev/en/latest/user-guide/indexing.html # https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html +from itertools import chain from typing import Literal from pytensor.graph.basic import Apply, Constant, Variable @@ -11,6 +12,7 @@ from pytensor.tensor.basic import as_tensor from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice from pytensor.xtensor.basic import XOp, xtensor_from_tensor +from pytensor.xtensor.shape import broadcast from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor @@ -195,6 +197,15 @@ def combine_dim_info(idx_dim, idx_dim_shape): output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) return Apply(self, [x, *idxs], [output]) + def vectorize_node(self, node, new_x, *new_idxs, new_dim): + # new_x may have dims in different order + # we pair each pre-existing dim to the respective index + # with new dims having simply a slice(None) + old_x, *_ = node.inputs + dims_to_idxs = dict(zip(old_x.dims, new_idxs, strict=True)) + new_idxs = tuple(dims_to_idxs.get(dim, slice(None)) for dim in new_x.dims) + return [self(new_x, *new_idxs)] + index = Index() @@ -226,6 +237,22 @@ def make_node(self, x, y, *idxs): out = x.type() return Apply(self, [x, y, *idxs], [out]) + def vectorize_node(self, node, *new_inputs): + # If y or the indices have new dimensions we need to broadcast_x + exclude: set[str] = set( + chain.from_iterable(old_inp.dims for old_inp in node.inputs) + ) + for new_inp, old_inp in zip(new_inputs, node.inputs, strict=True): + # Note: This check may be too conservative + if invalid_new_dims := ((set(new_inp.dims) - set(old_inp.dims)) & exclude): + raise NotImplementedError( + f"Vectorize of {self} is undefined because one of the inputs {new_inp} new dimensions " + f"was present in the old inputs: {sorted(invalid_new_dims)}" + ) + new_x, *_ = broadcast(*new_inputs, exclude=tuple(exclude)) + _, new_y, *new_idxs = new_inputs + return self.make_node(new_x, new_y, *new_idxs) + index_assignment = IndexUpdate("set") index_increment = IndexUpdate("inc") diff --git a/pytensor/xtensor/reduction.py b/pytensor/xtensor/reduction.py index f3d1677ae1..13ae08eafb 100644 --- a/pytensor/xtensor/reduction.py +++ b/pytensor/xtensor/reduction.py @@ -46,6 +46,9 @@ def make_node(self, x): output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) return Apply(self, [x], [output]) + def vectorize_node(self, node, new_x, new_dim): + return [self(new_x)] + def _process_user_dims(x: 
XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]: if isinstance(dim, str): @@ -117,6 +120,9 @@ def make_node(self, x): out = x.type() return Apply(self, [x], [out]) + def vectorize_node(self, node, new_x, new_dim): + return [self(new_x)] + def cumreduce(x, dim: REDUCE_DIM, *, binary_op): x = as_xtensor(x) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index ba564feb4f..1d3d607d97 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -68,6 +68,9 @@ def make_node(self, x): ) return Apply(self, [x], [output]) + def vectorize_node(self, node, new_x, new_dim): + return [self(new_x)] + def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]): if dim is not None: @@ -146,6 +149,14 @@ def make_node(self, x, *unstacked_length): ) return Apply(self, [x, *unstacked_lengths], [output]) + def vectorize_node(self, node, new_x, *new_unstacked_length, new_dim): + new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length] + if not all(ul.type.ndim == 0 for ul in new_unstacked_length): + raise NotImplementedError( + f"Vectorization of {self} with batched unstacked_length not implemented" + ) + return [self(new_x, *new_unstacked_length)] + def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]): if dim is not None: @@ -189,6 +200,11 @@ def make_node(self, x): ) return Apply(self, [x], [output]) + def vectorize_node(self, node, new_x, new_dim): + old_dims = self.dims + new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims) + return [type(self)(dims=(*new_dims, *old_dims))(new_x)] + def transpose( x, @@ -302,6 +318,9 @@ def make_node(self, *inputs): output = xtensor(dtype=dtype, dims=dims, shape=shape) return Apply(self, inputs, [output]) + def vectorize_node(self, node, *new_inputs, new_dim): + return [self(*new_inputs)] + def concat(xtensors, dim: str): """Concatenate a sequence of XTensorVariables along a specified dimension. @@ -383,6 +402,9 @@ def make_node(self, x): ) return Apply(self, [x], [out]) + def vectorize_node(self, node, new_x, new_dim): + return [self(new_x)] + def squeeze(x, dim: str | Sequence[str] | None = None): """Remove dimensions of size 1 from an XTensorVariable.""" @@ -442,6 +464,14 @@ def make_node(self, x, size): ) return Apply(self, [x, size], [out]) + def vectorize_node(self, node, new_x, new_size, new_dim): + new_size = new_size.squeeze() + if new_size.type.ndim != 0: + raise NotImplementedError( + f"Vectorization of {self} with batched new_size not implemented" + ) + return self.make_node(new_x, new_size) + def expand_dims(x, dim=None, axis=None, **dim_kwargs): """Add one or more new dimensions to an XTensorVariable.""" @@ -537,6 +567,19 @@ def make_node(self, *inputs): return Apply(self, inputs, outputs) + def vectorize_node(self, node, *new_inputs, new_dim): + if exclude_set := set(self.exclude): + for new_x, old_x in zip(new_inputs, node.inputs, strict=True): + if invalid_excluded := ( + (set(new_x.dims) - set(old_x.dims)) & exclude_set + ): + raise NotImplementedError( + f"Vectorize of {self} is undefined because one of the inputs {new_x} " + f"has an excluded dimension {sorted(invalid_excluded)} that it did not have before."
+ ) + + return self(*new_inputs, return_list=True) + def broadcast( *args, exclude: str | Sequence[str] | None = None diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 98b541083a..664df1abcd 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -1044,7 +1044,7 @@ def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None) if isinstance(x, Variable): if isinstance(x.type, XTensorType): - if (dims is None) or (x.type.dims == dims): + if (dims is None) or (x.type.dims == tuple(dims)): return x else: raise ValueError( diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py index ec0033c542..962a73b41f 100644 --- a/pytensor/xtensor/vectorization.py +++ b/pytensor/xtensor/vectorization.py @@ -1,21 +1,31 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence +from functools import singledispatch from itertools import chain +from typing import Literal import numpy as np +from pytensor import Variable, shared from pytensor import scalar as ps -from pytensor import shared from pytensor.graph import Apply, Op +from pytensor.graph.replace import _vectorize_node, graph_replace +from pytensor.graph.traversal import toposort, truncated_graph_inputs +from pytensor.graph.type import HasShape from pytensor.scalar import discrete_dtypes -from pytensor.tensor import tensor +from pytensor.tensor import ( + TensorVariable, + broadcast_shape, + broadcast_to, + tensor, +) from pytensor.tensor.random.op import RNGConsumerOp from pytensor.tensor.random.type import RandomType from pytensor.tensor.utils import ( get_static_shape_from_size_variables, ) from pytensor.utils import unzip -from pytensor.xtensor.basic import XOp -from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor +from pytensor.xtensor.basic import XOp, XTypeCastOp +from pytensor.xtensor.type import XTensorType, XTensorVariable, as_xtensor, xtensor def combine_dims_and_shape( @@ -69,6 +79,9 @@ def make_node(self, *inputs): ] return Apply(self, inputs, outputs) + def vectorize_node(self, node, *new_inputs, new_dim): + return self(*new_inputs, return_list=True) + class XBlockwise(XOp): __props__ = ("core_op", "core_dims") @@ -136,6 +149,9 @@ def make_node(self, *inputs): ] return Apply(self, inputs, outputs) + def vectorize_node(self, node, *new_inputs, new_dim): + return self(*new_inputs, return_list=True) + class XRV(XOp, RNGConsumerOp): """Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics. 
@@ -283,3 +299,278 @@ def make_node(self, rng, *extra_dim_lengths_and_params): ) return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out]) + + def vectorize_node(self, node, *new_inputs, new_dim): + if len(new_inputs) != len(node.inputs): + raise NotImplementedError( + f"Vectorization of {self} with additional extra_dim_lengths not implemented, " + "as it can't infer new dimension labels" + ) + new_rng, *new_extra_dim_lengths_and_params = new_inputs + new_extra_dim_lengths, new_params = ( + new_extra_dim_lengths_and_params[: len(self.extra_dims)], + new_extra_dim_lengths_and_params[len(self.extra_dims) :], + ) + + new_extra_dim_lengths = [dl.squeeze() for dl in new_extra_dim_lengths] + if not all(dl.type.ndim == 0 for dl in new_extra_dim_lengths): + raise NotImplementedError( + f"Vectorization of {self} with batched extra_dim_lengths not implemented, " + ) + + return self.make_node(new_rng, *new_extra_dim_lengths, *new_params).outputs + + +@_vectorize_node.register(XOp) +@_vectorize_node.register(XTypeCastOp) +def vectorize_xop(op, node, *new_inputs) -> Apply: + # This gets called by regular graph_replace, which isn't aware of xtensor and doesn't have a concept of `new_dim` + return vectorize_xnode(node.op, node, *new_inputs, new_dim=None) + + +@singledispatch +def vectorize_xnode( + op: XOp | XTypeCastOp, + node: Apply, + *batched_inputs: Variable, + new_dim: str | None = None, +) -> tuple[Variable]: + """Returns vectorized version of node with new batched inputs.""" + + all_old_dims_set = set( + chain.from_iterable( + x.dims + for x in (*node.inputs, *node.outputs) + if isinstance(x.type, XTensorType) + ) + ) + for new_inp, old_inp in zip(batched_inputs, node.inputs, strict=True): + if not ( + isinstance(new_inp.type, XTensorType) + and isinstance(old_inp.type, XTensorType) + ): + continue + + old_dims_set = set(old_inp.dims) + new_dims_set = set(new_inp.dims) + + # Validate that new inputs didn't drop pre-existing dims + if missing_dims := old_dims_set - new_dims_set: + raise ValueError( + f"Vectorized input {new_inp} is missing pre-existing dims: {sorted(missing_dims)}" + ) + # Or have new dimensions that were already in the graph + if new_core_dims := ((new_dims_set - old_dims_set) & all_old_dims_set): + raise ValueError( + f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}" + ) + + def align_dims(new_x, old_x): + if isinstance(new_x.type, XTensorType): + if new_dim is not None and new_dim in new_x.dims: + return new_x.transpose(new_dim, *old_x.dims) + else: + return new_x.transpose(..., *old_x.dims) + else: + return new_x + + vectorized_outs = op.vectorize_node( + node, + *( + align_dims(new_x, old_x) + for new_x, old_x in zip(batched_inputs, node.inputs) + ), + new_dim=new_dim, + ) + + return tuple( + align_dims(new_out, old_out) + for new_out, old_out in zip(vectorized_outs, node.outputs, strict=True) + ) + + +def _vectorize_single_dim(outputs, replace, new_dim: str): + inputs = truncated_graph_inputs(outputs, ancestors_to_include=replace.keys()) + new_inputs = [replace.get(inp, inp) for inp in inputs] + + vect_vars = dict(zip(inputs, new_inputs, strict=True)) + for node in toposort(outputs, blockers=inputs): + vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs] + + if isinstance(node.op, XOp | XTypeCastOp): + node_vect_outs = vectorize_xnode( + node.op, node, *vect_inputs, new_dim=new_dim + ) + else: + node_vect_outs = _vectorize_node(node.op, node, *vect_inputs) + if isinstance(node_vect_outs, 
Apply): + # Old API + node_vect_outs = node_vect_outs.outputs + + for output, vect_output in zip(node.outputs, node_vect_outs, strict=True): + if output in vect_vars: + # This can happen when some outputs of a multi-output node are given a replacement, + # while some of the remaining outputs are still needed in the graph. + # We make sure we don't overwrite the provided replacement with the newly vectorized output + continue + vect_vars[output] = vect_output + + return [vect_vars[out] for out in outputs] + + +def vectorize_graph( + outputs: Variable | Sequence[Variable], + replace: Mapping[Variable, Variable], + *, + new_tensor_dims: Sequence[str] = (), +): + """Dimension-aware vectorize_graph. + + This is an extension to `pytensor.graph.replace.vectorize_graph` that correctly handles mixed XTensor/TensorVariable graphs. + Vectorization rule for batch TensorVariables works like regular `vectorize_graph`, with batched axes assumed to be aligned + positionally and present on the left of the new inputs. They must be given labels with `new_tensor_dims` argument (left to right), + for correct interaction with XTensorVariables. + + Batched XTensorVariables may contain new dimensions anywhere. + These can include dimensions in `new_tensor_dims`, as well as other new dimensions observable in XTensorVariable.dims. + New dimensions for a given input should not have existed in the original graph. + """ + if isinstance(outputs, Sequence): + seq_outputs = outputs + else: + seq_outputs = [outputs] + + if not all( + isinstance(key, Variable) and isinstance(value, Variable) + for key, value in replace.items() + ): + raise ValueError(f"Some of the replaced items are not Variables: {replace}") + + # Collect new dimensions and sizes, and validate + new_xtensor_sizes: dict[str, TensorVariable] = {} + new_tensor_sizes: list[tuple[TensorVariable | Literal[1], ...]] = [] + for old, new in replace.items(): + if isinstance(new.type, XTensorType): + old_var_dims_set = set(old.dims) + new_var_dims_set = set(new.dims) + if missing_dims := old_var_dims_set - new_var_dims_set: + raise ValueError( + f"Vectorized input {new} is missing pre-existing dims: {sorted(missing_dims)}" + ) + new_xtensor_sizes.update( + {d: s for d, s in new.sizes.items() if d not in old_var_dims_set} + ) + elif isinstance(new, TensorVariable): + n_new_dims = new.type.ndim - old.type.ndim + if n_new_dims < 0: + raise ValueError( + f"Vectorized input {new} is missing pre-existing dims {new.ndim=}, {old.ndim=}" + ) + if n_new_dims > len(new_tensor_dims): + if new_tensor_dims: + raise ValueError( + f"TensorVariable replacement {new} has {n_new_dims} batch dimensions. " + f"You must specify `new_tensor_dims` to label these." + ) + else: + raise ValueError( + f"TensorVariable replacement {new} has {n_new_dims} batch dimensions " + f"but only {new_tensor_dims=} were specified. 
" + ) + new_tensor_sizes.append( + tuple( + 1 if b else s + for s, b in zip( + tuple(new.shape)[:n_new_dims], + new.type.broadcastable[:n_new_dims], + ) + ) + ) + + elif isinstance(new.type, HasShape) and new.type.ndim != old.type.ndim: + raise NotImplementedError( + f"vectorize_graph does not know how to handle batched input {new} of type {new.type}" + ) + + # Align xtensor batch dimensions on the left, and broadcast tensor batch dimensions + new_dims = ( + *(dim for dim in new_xtensor_sizes if dim not in new_tensor_dims), + *new_tensor_dims, + ) + + if not new_dims: + return graph_replace(outputs, replace, strict=False) + + # Create a mapping from new_tensor_dims -> broadcasted shape from tensors + if new_tensor_dims: + new_tensor_sizes = broadcast_shape(*new_tensor_sizes, arrays_are_shapes=True) + if len(new_tensor_sizes) != len(new_tensor_dims): + raise ValueError( + f"{len(new_tensor_dims)} tensor dims were specified, but only {len(new_tensor_sizes)} were found in the new inputs" + ) + new_tensor_sizes = dict(zip(new_tensor_dims, new_tensor_sizes)) + else: + new_tensor_sizes = {} + # Give preference to xtensor sizes as that doesn't require any broadcasting + new_sizes = tuple( + new_xtensor_sizes.get(dim, new_tensor_sizes.get(dim, 1)) for dim in new_dims + ) + # Align batch dimensions on the left (*xtensor_unique_batch_dims, *tensor_batch_dims, ...) + # We broadcast tensor batch dims as they may have been length 1 + aligned_replace = {} + for old, new in replace.items(): + if isinstance(new, XTensorVariable): + new = new.transpose(*new_dims, ..., missing_dims="ignore") + elif isinstance(new, TensorVariable): + new = broadcast_to( + new, shape=(*new_sizes, *tuple(new.shape)[-len(new_sizes) :]) + ) + aligned_replace[old] = new + del replace + + seq_vect_outputs = seq_outputs + remaining_new_dims = list(new_dims) + while remaining_new_dims: + new_dim = remaining_new_dims.pop() + + if remaining_new_dims: + # We need to use a dummy inputs to batch graph once at a time + # We drop all the dims that are still in `remaining_new_dims` + # Create a mapping: original -> intermediate_batched + single_dim_replace = {} + for old, new in aligned_replace.items(): + n_remaining_dims = len(remaining_new_dims) + if isinstance(new, XTensorVariable): + intermediate_dims, intermediate_shape = unzip( + ( + (d, s) + for d, s in zip(new.type.dims, new.type.shape) + if d not in remaining_new_dims + ), + n=2, + ) + intermediate_type = new.type.clone( + dims=intermediate_dims, shape=intermediate_shape + ) + elif isinstance(new, TensorVariable): + intermediate_type = new.type.clone( + shape=new.type.shape[n_remaining_dims:] + ) + else: + intermediate_type = new.type + single_dim_replace[old] = intermediate_type() + # Updated aligned replace mapping: intermediate_batched -> final_batched + aligned_replace = dict( + zip(single_dim_replace.values(), aligned_replace.values()) + ) + else: + single_dim_replace = aligned_replace + seq_vect_outputs = _vectorize_single_dim( + seq_vect_outputs, single_dim_replace, new_dim + ) + + if isinstance(outputs, Sequence): + return seq_vect_outputs + else: + [vect_output] = seq_vect_outputs + return vect_output diff --git a/tests/graph/test_replace.py b/tests/graph/test_replace.py index f0d64ee76b..94ddfa1771 100644 --- a/tests/graph/test_replace.py +++ b/tests/graph/test_replace.py @@ -6,10 +6,10 @@ from pytensor import config, function, shared from pytensor.graph.basic import equal_computations from pytensor.graph.replace import ( + _vectorize_node, clone_replace, graph_replace, 
vectorize_graph, - vectorize_node, ) from pytensor.graph.traversal import graph_inputs from pytensor.tensor import dvector, fvector, vector @@ -277,7 +277,7 @@ def test_multi_output_node(self): # Cases where either x or both of y1 and y2 are given replacements new_out = vectorize_graph(out, {x: new_x}) - expected_new_out = pt.add(*vectorize_node(node, new_x).outputs) + expected_new_out = pt.add(*_vectorize_node(node.op, node, new_x).outputs) assert equal_computations([new_out], [expected_new_out]) new_out = vectorize_graph(out, {y1: new_y1, y2: new_y2}) @@ -291,7 +291,9 @@ def test_multi_output_node(self): # Special case where x is given a replacement as well as only one of y1 and y2 # The graph combines the replaced variable with the other vectorized output new_out = vectorize_graph(out, {x: new_x, y1: new_y1}) - expected_new_out = pt.add(new_y1, vectorize_node(node, new_x).outputs[1]) + expected_new_out = pt.add( + new_y1, _vectorize_node(node.op, node, new_x).outputs[1] + ) assert equal_computations([new_out], [expected_new_out]) def test_multi_output_node_random_variable(self): @@ -324,3 +326,35 @@ def test_multi_output_node_random_variable(self): new_beta1 * pt.exp(new_beta0 + 1), ] assert equal_computations(new_outs, expected_new_outs) + + def test_non_variable_raises(self): + x = pt.scalar("x", dtype=int) + y = pt.scalar("y", dtype=int) + non_variable_shape = (x, y) + variable_shape = pt.as_tensor(non_variable_shape) + + non_variable_shape_out = pt.zeros(non_variable_shape) + variable_shape_out = pt.zeros(variable_shape) + + non_variable_batch_shape = (non_variable_shape, non_variable_shape) + variable_batch_shape = pt.stacklists(non_variable_batch_shape) + + msg = r"Some of the replaced items are not Variables" + with pytest.raises(ValueError, match=msg): + vectorize_graph( + non_variable_shape_out, {non_variable_shape: non_variable_batch_shape} + ) + + with pytest.raises(ValueError, match=msg): + vectorize_graph( + variable_shape_out, {variable_shape: non_variable_batch_shape} + ) + + batch_out = vectorize_graph( + variable_shape_out, {variable_shape: variable_batch_shape} + ) + assert batch_out.type.shape == (2, None, None) + np.testing.assert_array_equal( + batch_out.eval({x: 3, y: 4}), + np.zeros((2, 3, 4)), + ) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 7bef3f759f..2b48287f5a 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -11,7 +11,7 @@ from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.gradient import grad from pytensor.graph import Apply, FunctionGraph, Op, rewrite_graph -from pytensor.graph.replace import vectorize_graph, vectorize_node +from pytensor.graph.replace import _vectorize_node, vectorize_graph from pytensor.link.numba import NumbaLinker from pytensor.raise_op import assert_op from pytensor.tensor import ( @@ -95,8 +95,8 @@ def test_vectorize_blockwise(): tns = tensor(shape=(None, None, None)) # Something that falls back to Blockwise - node = MatrixInverse()(mat).owner - vect_node = vectorize_node(node, tns) + out = MatrixInverse()(mat) + vect_node = vectorize_graph(out, {mat: tns}).owner assert isinstance(vect_node.op, Blockwise) and isinstance( vect_node.op.core_op, MatrixInverse ) @@ -105,7 +105,7 @@ def test_vectorize_blockwise(): # Useless blockwise tns4 = tensor(shape=(5, None, None, None)) - new_vect_node = vectorize_node(vect_node, tns4) + new_vect_node = vectorize_graph(vect_node.out, {tns: tns4}).owner assert new_vect_node.op is vect_node.op 
assert isinstance(new_vect_node.op, Blockwise) and isinstance( new_vect_node.op.core_op, MatrixInverse @@ -204,7 +204,7 @@ def test_vectorize_node_default_signature(): mat = tensor(shape=(5, None)) node = my_test_op.make_node(vec, mat) - vect_node = vectorize_node(node, mat, mat) + vect_node = _vectorize_node(node.op, node, mat, mat) assert isinstance(vect_node.op, Blockwise) and isinstance( vect_node.op.core_op, MyTestOp ) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 41ec020e3e..8bd35bca28 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -16,7 +16,7 @@ from pytensor.compile.mode import Mode, get_default_mode from pytensor.graph.basic import Apply, Variable from pytensor.graph.fg import FunctionGraph -from pytensor.graph.replace import vectorize_node +from pytensor.graph.replace import vectorize_graph from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.numba import NumbaLinker @@ -1042,46 +1042,39 @@ def test_elemwise(self): vec = tensor(shape=(None,)) mat = tensor(shape=(None, None)) - node = exp(vec).owner - vect_node = vectorize_node(node, mat) - assert vect_node.op == exp - assert vect_node.inputs[0] is mat + out = exp(vec) + vect_out = vectorize_graph(out, {vec: mat}) + assert vect_out.owner.op == exp + assert vect_out.owner.inputs[0] is mat def test_dimshuffle(self): - vec = tensor(shape=(None,)) - mat = tensor(shape=(None, None)) - - node = exp(vec).owner - vect_node = vectorize_node(node, mat) - assert vect_node.op == exp - assert vect_node.inputs[0] is mat - col_mat = tensor(shape=(None, 1)) tcol_mat = tensor(shape=(None, None, 1)) - node = col_mat.dimshuffle(0).owner # drop column - vect_node = vectorize_node(node, tcol_mat) - assert isinstance(vect_node.op, DimShuffle) - assert vect_node.op.new_order == (0, 1) - assert vect_node.inputs[0] is tcol_mat - assert vect_node.outputs[0].type.shape == (None, None) + + out = col_mat.dimshuffle(0) # drop column + vect_out = vectorize_graph(out, {col_mat: tcol_mat}) + assert isinstance(vect_out.owner.op, DimShuffle) + assert vect_out.owner.op.new_order == (0, 1) + assert vect_out.owner.inputs[0] is tcol_mat + assert vect_out.owner.outputs[0].type.shape == (None, None) def test_CAReduce(self): mat = tensor(shape=(None, None)) tns = tensor(shape=(None, None, None)) - node = pt_sum(mat).owner - vect_node = vectorize_node(node, tns) - assert isinstance(vect_node.op, Sum) - assert vect_node.op.axis == (1, 2) - assert vect_node.inputs[0] is tns + out = pt_sum(mat) + vect_out = vectorize_graph(out, {mat: tns}) + assert isinstance(vect_out.owner.op, Sum) + assert vect_out.owner.op.axis == (1, 2) + assert vect_out.owner.inputs[0] is tns bool_mat = tensor(dtype="bool", shape=(None, None)) bool_tns = tensor(dtype="bool", shape=(None, None, None)) - node = pt_any(bool_mat, axis=-2).owner - vect_node = vectorize_node(node, bool_tns) - assert isinstance(vect_node.op, Any) - assert vect_node.op.axis == (1,) - assert vect_node.inputs[0] is bool_tns + out = pt_any(bool_mat, axis=-2) + vect_out = vectorize_graph(out, {bool_mat: bool_tns}) + assert isinstance(vect_out.owner.op, Any) + assert vect_out.owner.op.axis == (1,) + assert vect_out.owner.inputs[0] is bool_tns def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark): diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 9376f3ded4..1a8f02cb87 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -21,7 +21,7 @@ from 
pytensor.gradient import NullTypeGradError, grad, numeric_grad from pytensor.graph.basic import Variable, equal_computations from pytensor.graph.fg import FunctionGraph -from pytensor.graph.replace import vectorize_node +from pytensor.graph.replace import vectorize_graph from pytensor.graph.traversal import ancestors, applys_between from pytensor.link.c.basic import DualLinker from pytensor.link.numba import NumbaLinker @@ -1070,11 +1070,10 @@ def test_vectorize(self, core_axis, batch_axis): argmax_x = argmax(x, axis=core_axis) - arg_max_node = argmax_x.owner - new_node = vectorize_node(arg_max_node, batch_x) + vect_out = vectorize_graph(argmax_x, {x: batch_x}) - assert isinstance(new_node.op, Argmax) - assert new_node.op.axis == batch_axis + assert isinstance(vect_out.owner.op, Argmax) + assert vect_out.owner.op.axis == batch_axis class TestArgminArgmax: diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 47965b0b33..517c8ad442 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -8,7 +8,7 @@ from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config from pytensor.graph.basic import Variable, equal_computations -from pytensor.graph.replace import clone_replace, vectorize_node +from pytensor.graph.replace import clone_replace, vectorize_graph from pytensor.graph.type import Type from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row @@ -742,9 +742,9 @@ class TestVectorize: def test_shape(self): vec = tensor(shape=(None,), dtype="float64") mat = tensor(shape=(None, None), dtype="float64") - node = shape(vec).owner + out = shape(vec) - [vect_out] = vectorize_node(node, mat).outputs + vect_out = vectorize_graph(out, {vec: mat}) assert equal_computations( [vect_out], [broadcast_to(mat.shape[1:], (*mat.shape[:1], 1))] ) @@ -758,8 +758,8 @@ def test_shape(self): mat = tensor(shape=(None, None), dtype="float64") tns = tensor(shape=(None, None, None, None), dtype="float64") - node = shape(mat).owner - [vect_out] = vectorize_node(node, tns).outputs + out = shape(mat) + vect_out = vectorize_graph(out, {mat: tns}) assert equal_computations( [vect_out], [broadcast_to(tns.shape[2:], (*tns.shape[:2], 2))] ) @@ -779,11 +779,13 @@ def test_reshape(self): vec = tensor(shape=(None,), dtype="float64") mat = tensor(shape=(None, None), dtype="float64") - shape = (-1, x) - node = reshape(vec, shape).owner + shape = as_tensor_variable([-1, x]) + out = reshape(vec, shape) - [vect_out] = vectorize_node(node, mat, shape).outputs - assert equal_computations([vect_out], [reshape(mat, (*mat.shape[:1], -1, x))]) + vect_out = vectorize_graph(out, {vec: mat}) + utt.assert_equal_computations( + [vect_out], [reshape(mat, (*mat.shape[:1], *stack((-1, x))))] + ) x_test_value = 2 mat_test_value = np.ones((5, 6)) @@ -796,11 +798,11 @@ def test_reshape(self): ) new_shape = (5, -1, x) - [vect_out] = vectorize_node(node, mat, new_shape).outputs - assert equal_computations([vect_out], [reshape(mat, new_shape)]) + vect_out = vectorize_graph(out, {vec: mat, shape: new_shape}) + utt.assert_equal_computations([vect_out], [reshape(mat, new_shape)]) new_shape = stack([[-1, x], [x - 1, -1]], axis=0) - [vect_out] = vectorize_node(node, vec, new_shape).outputs + vect_out = vectorize_graph(out, {shape: new_shape}) vec_test_value = np.arange(6) np.testing.assert_allclose( vect_out.eval({x: 3, vec: vec_test_value}), @@ -811,13 +813,13 @@ def test_reshape(self): ValueError, match="Invalid 
shape length passed into vectorize node of Reshape", ): - vectorize_node(node, vec, (5, 2, x)) + vectorize_graph(out, {shape: (5, 2, x)}) with pytest.raises( ValueError, match="Invalid shape length passed into vectorize node of Reshape", ): - vectorize_node(node, mat, (5, 3, 2, x)) + vectorize_graph(out, {vec: mat, shape: (5, 3, 2, x)}) def test_specify_shape(self): x = scalar("x", dtype=int) @@ -825,27 +827,9 @@ def test_specify_shape(self): tns = tensor(shape=(None, None, None)) shape = (x, None) - node = specify_shape(mat, shape).owner - vect_node = vectorize_node(node, tns, *shape) - assert equal_computations( - vect_node.outputs, [specify_shape(tns, (None, x, None))] - ) - - new_shape = (5, 2, x) - vect_node = vectorize_node(node, tns, *new_shape) - assert equal_computations(vect_node.outputs, [specify_shape(tns, (5, 2, x))]) + out = specify_shape(mat, shape) + vect_out = vectorize_graph(out, {mat: tns}) + assert equal_computations([vect_out], [specify_shape(tns, (None, x, None))]) with pytest.raises(NotImplementedError): - vectorize_node(node, mat, *([x, x], None)) - - with pytest.raises( - ValueError, - match="Invalid number of shape arguments passed into vectorize node of SpecifyShape", - ): - vectorize_node(node, mat, *(5, 2, x)) - - with pytest.raises( - ValueError, - match="Invalid number of shape arguments passed into vectorize node of SpecifyShape", - ): - vectorize_node(node, tns, *(5, 3, 2, x)) + vectorize_graph(out, {x: [x, x]}) diff --git a/tests/tensor/test_special.py b/tests/tensor/test_special.py index 24a172860f..54174e025e 100644 --- a/tests/tensor/test_special.py +++ b/tests/tensor/test_special.py @@ -9,7 +9,7 @@ from pytensor.compile.function import function from pytensor.configdefaults import config -from pytensor.graph.replace import vectorize_node +from pytensor.graph.replace import vectorize_graph from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.special import ( LogSoftmax, @@ -168,18 +168,18 @@ def test_vectorize_softmax(op, constructor, core_axis, batch_axis): x = tensor(shape=(5, 5, 5, 5)) batch_x = tensor(shape=(3, 5, 5, 5, 5)) - node = constructor(x, axis=core_axis).owner - assert isinstance(node.op, op) + out = constructor(x, axis=core_axis) + assert isinstance(out.owner.op, op) - new_node = vectorize_node(node, batch_x) + new_out = vectorize_graph(out, {x: batch_x}) if len(batch_axis) == 1: - assert isinstance(new_node.op, op) - assert (new_node.op.axis,) == batch_axis + assert isinstance(new_out.owner.op, op) + assert (new_out.owner.op.axis,) == batch_axis else: - assert isinstance(new_node.op, Blockwise) and isinstance( - new_node.op.core_op, op + assert isinstance(new_out.owner.op, Blockwise) and isinstance( + new_out.owner.op.core_op, op ) - assert new_node.op.core_op.axis == core_axis + assert new_out.owner.op.core_op.axis == core_axis def test_poch(): diff --git a/tests/xtensor/test_basic.py b/tests/xtensor/test_basic.py index 7dd5ff578f..927734679e 100644 --- a/tests/xtensor/test_basic.py +++ b/tests/xtensor/test_basic.py @@ -1,8 +1,21 @@ import numpy as np +import pytest + + +pytest.importorskip("xarray") from pytensor import function -from pytensor.xtensor.basic import Rename +from pytensor.graph import vectorize_graph +from pytensor.tensor import matrix, vector +from pytensor.xtensor import as_xtensor +from pytensor.xtensor.basic import ( + Rename, + rename, + tensor_from_xtensor, + xtensor_from_tensor, +) from pytensor.xtensor.type import xtensor +from tests.xtensor.util import check_vectorization, xr_random_like def 
test_shape_feature_does_not_see_xop(): @@ -24,3 +37,38 @@ def infer_shape(self, node, inputs, outputs): fn = function([x], out) np.testing.assert_allclose(fn([1, 2, 3]), [0, 0, 0]) assert not CALLED + + +def test_rename_vectorize(): + ab = xtensor("ab", dims=("a", "b"), shape=(2, 3), dtype="float64") + check_vectorization(ab, rename(ab, a="c")) + + +def test_xtensor_from_tensor_vectorize(): + t = vector("t") + x = xtensor_from_tensor(t, dims=("a",)) + + t_batched = matrix("t_batched") + with pytest.raises( + NotImplementedError, match=r"Vectorization of .* not implemented" + ): + vectorize_graph([x], {t: t_batched}) + + +def test_tensor_from_xtensor_vectorize(): + x = xtensor("x", dims=("a",), shape=(3,)) + y = tensor_from_xtensor(x) + + x_val = xr_random_like(x) + x_batched_val = x_val.expand_dims({"batch": 2}) + x_batched = as_xtensor(x_batched_val).type("x_batched") + + [y_batched] = vectorize_graph([y], {x: x_batched}) + + # y_batched should be a Matrix (batch, a) -> (2, 3) + assert y_batched.type.shape == (2, 3) + + fn = function([x_batched], y_batched) + res = fn(x_batched_val) + + np.testing.assert_allclose(res, x_batched_val.values) diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py index a292868e72..f3ba5b773c 100644 --- a/tests/xtensor/test_indexing.py +++ b/tests/xtensor/test_indexing.py @@ -14,6 +14,7 @@ from pytensor.xtensor import xtensor from tests.unittest_tools import assert_equal_computations from tests.xtensor.util import ( + check_vectorization, xr_arange_like, xr_assert_allclose, xr_function, @@ -542,3 +543,35 @@ def test_empty_update_index(): fn = xr_function([x], out1) x_test = xr_random_like(x) xr_assert_allclose(fn(x_test), x_test + 1) + + +def test_indexing_vectorize(): + abc = xtensor(dims=("a", "b", "c"), shape=(3, 5, 7)) + a_idx = xtensor(dims=("a",), shape=(5,), dtype="int64") + c_idx = xtensor(dims=("c",), shape=(3,), dtype="int64") + + abc_val = xr_random_like(abc) + a_idx_val = DataArray([0, 1, 0, 2, 0], dims=("a",)) + c_idx_val = DataArray([0, 5, 6], dims=("c",)) + + check_vectorization([abc, a_idx], [abc.isel(a=a_idx)], [abc_val, a_idx_val]) + check_vectorization( + [abc, a_idx], [abc.isel(a=a_idx.rename(a="b"))], [abc_val, a_idx_val] + ) + check_vectorization( + [abc, a_idx], [abc.isel(a=a_idx.rename(a="d"))], [abc_val, a_idx_val] + ) + check_vectorization([abc, a_idx], [abc.isel(c=a_idx[:3])], [abc_val, a_idx_val]) + check_vectorization( + [abc, a_idx], [abc.isel(a=a_idx, c=a_idx)], [abc_val, a_idx_val] + ) + check_vectorization( + [abc, a_idx, c_idx], + [abc.isel(a=a_idx, c=c_idx)], + [abc_val, a_idx_val, c_idx_val], + ) + + +def test_index_update_vectorize(): + # TODO + pass diff --git a/tests/xtensor/test_linalg.py b/tests/xtensor/test_linalg.py index b365b1a0a8..438c6cf845 100644 --- a/tests/xtensor/test_linalg.py +++ b/tests/xtensor/test_linalg.py @@ -16,7 +16,7 @@ from pytensor.xtensor.linalg import cholesky, solve from pytensor.xtensor.type import xtensor -from tests.xtensor.util import xr_assert_allclose, xr_function +from tests.xtensor.util import check_vectorization, xr_assert_allclose, xr_function def test_cholesky(): @@ -74,3 +74,23 @@ def test_solve_matrix_b(): fn(a_test, b_test), xr_solve(a_test, b_test, dims=["country", "city", "district"]), ) + + +def test_linalg_vectorize(): + # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific + + a = xtensor("b", dims=("a",), shape=(3,)) + ab = xtensor("a", dims=("a", "b"), shape=(3, 3)) + test_spd = np.random.randn(3, 3) + test_spd = 
test_spd @ test_spd.T + + check_vectorization( + [ab], + [cholesky(ab, dims=("b", "a"))], + input_vals=[DataArray(test_spd, dims=("a", "b"))], + ) + + check_vectorization( + [ab, a], + [solve(ab, a, dims=("a", "b"))], + ) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index 3d2755f6ac..4f3714b657 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -17,7 +17,12 @@ from pytensor.xtensor.basic import rename from pytensor.xtensor.math import add, exp, logsumexp from pytensor.xtensor.type import xtensor -from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function +from tests.xtensor.util import ( + check_vectorization, + xr_arange_like, + xr_assert_allclose, + xr_function, +) def test_all_scalar_ops_are_wrapped(): @@ -340,3 +345,11 @@ def test_dot_errors(): match=r"(Input operand 1 has a mismatch in its core dimension 0|incompatible array sizes for np.dot)", ): fn(x_test, y_test) + + +def test_xelemwise_vectorize(): + ab = xtensor("ab", dims=("a", "b"), shape=(2, 3)) + bc = xtensor("bc", dims=("b", "c"), shape=(3, 5)) + + check_vectorization([ab], [exp(ab)]) + check_vectorization([ab, bc], [ab + bc]) diff --git a/tests/xtensor/test_random.py b/tests/xtensor/test_random.py index 54e07da4a8..ffd5ffb50d 100644 --- a/tests/xtensor/test_random.py +++ b/tests/xtensor/test_random.py @@ -9,6 +9,7 @@ from copy import deepcopy import numpy as np +from xarray import DataArray import pytensor.tensor.random as ptr import pytensor.xtensor.random as pxr @@ -26,6 +27,7 @@ normal, ) from pytensor.xtensor.vectorization import XRV +from tests.xtensor.util import check_vectorization def lower_rewrite(vars): @@ -438,3 +440,27 @@ def test_multivariate_normal(): ): # cov must have both core_dims multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "missing_cols")) + + +def test_xrv_vectorize(): + # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific + + n = xtensor("n", dims=("n",), shape=(3,), dtype=int) + pna = xtensor("p", dims=("p", "n", "a"), shape=(5, 3, 2)) + out = multinomial(n, pna, core_dims=("p",), extra_dims={"extra": 5}) + check_vectorization( + [n, pna], + [out], + input_vals=[ + DataArray([3, 5, 10], dims=("n",)), + DataArray( + np.random.multinomial(n=1, pvals=np.ones(5) / 5, size=(2, 3)).T, + dims=("p", "n", "a"), + ), + ], + ) + + +def test_xrv_batch_extra_dim_vectorize(): + # TODO: Check it raises NotImplementedError when we try to batch the extra_dim of an xrv + pass diff --git a/tests/xtensor/test_reduction.py b/tests/xtensor/test_reduction.py index cce41a2011..4816d79bb3 100644 --- a/tests/xtensor/test_reduction.py +++ b/tests/xtensor/test_reduction.py @@ -8,7 +8,12 @@ import xarray as xr from pytensor.xtensor.type import xtensor -from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function +from tests.xtensor.util import ( + check_vectorization, + xr_arange_like, + xr_assert_allclose, + xr_function, +) @pytest.mark.parametrize( @@ -99,3 +104,16 @@ def test_discrete_reduction_upcasting(signed): res = fn(x_val) np.testing.assert_allclose(res, [test_val, test_val**2]) xr_assert_allclose(res, x_val.cumprod()) + + +def test_reduction_vectorize(): + # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific + abc = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7)) + + check_vectorization([abc], [abc.sum(dim="a")]) + check_vectorization([abc], [abc.max(dim=("a", "c"))]) + check_vectorization([abc], [abc.all()]) + + check_vectorization([abc], 
[abc.cumsum(dim="b")]) + check_vectorization([abc], [abc.cumsum(dim=("c", "b"))]) + check_vectorization([abc], [abc.cumprod()]) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index efc3bfcd10..fa2c9eddf2 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -15,7 +15,8 @@ from xarray import ones_like as xr_ones_like from xarray import zeros_like as xr_zeros_like -from pytensor.tensor import scalar +from pytensor.graph import vectorize_graph +from pytensor.tensor import scalar, vector from pytensor.xtensor.shape import ( broadcast, concat, @@ -25,8 +26,9 @@ unstack, zeros_like, ) -from pytensor.xtensor.type import xtensor +from pytensor.xtensor.type import as_xtensor, xtensor from tests.xtensor.util import ( + check_vectorization, xr_arange_like, xr_assert_allclose, xr_function, @@ -800,3 +802,94 @@ def test_zeros_like(): expected1 = xr_zeros_like(x_test) xr_assert_allclose(result1, expected1) assert result1.dtype == expected1.dtype + + +def test_shape_ops_vectorize(): + a1 = xtensor("a1", dims=("a", "1"), shape=(2, 1), dtype="float64") + ab = xtensor("ab", dims=("a", "b"), shape=(2, 3), dtype="float64") + abc = xtensor("abc", dims=("a", "b", "c"), shape=(2, 3, 5), dtype="float64") + a_bc_d = xtensor("a_bc_d", dims=("a", "bc", "d"), shape=(4, 15, 7)) + + check_vectorization(abc, abc.transpose("b", "c", "a")) + check_vectorization(abc, abc.transpose("b", ...)) + + check_vectorization(abc, stack(abc, new_dim=("a", "c"))) + check_vectorization(a_bc_d, unstack(a_bc_d, bc=dict(b=3, c=5))) + + check_vectorization([abc, ab], concat([abc, ab], dim="a")) + + check_vectorization(a1, a1.squeeze("1")) + + check_vectorization(abc, abc.expand_dims(d=5)) + + check_vectorization([ab, abc], broadcast(ab, abc)) + check_vectorization([ab, abc, a1], broadcast(ab, abc, a1, exclude="1")) + # a is longer in a_bc_d than in ab and abc, helper can't handle that + # check_vectorization([ab, abc, a_bc_d], broadcast(ab, abc, a_bc_d, exclude="a")) + + +def test_broadcast_exclude_vectorize(): + x = xtensor("x", dims=("a", "b"), shape=(3, 4)) + y = xtensor("y", dims=("b", "c"), shape=(7, 5)) + + # broadcast exclude "b" + out_x, out_y = broadcast(x, y, exclude=("b",)) + + # Manual vectorization check + x_val = xr_random_like(x) + y_val = xr_random_like(y) + + # Vectorize inputs manually to pass to vectorize_graph + x_batch_val = x_val.expand_dims({"batch": 2}) + y_batch_val = y_val.expand_dims({"batch": 2}) + + x_batch = as_xtensor(x_batch_val).type("x_batch") + y_batch = as_xtensor(y_batch_val).type("y_batch") + + [out_x_vec, out_y_vec] = vectorize_graph([out_x, out_y], {x: x_batch, y: y_batch}) + + fn = xr_function([x_batch, y_batch], [out_x_vec, out_y_vec]) + res_x, res_y = fn(x_batch_val, y_batch_val) + + expected_x = [] + expected_y = [] + for i in range(2): + ex_x, ex_y = xr_broadcast( + x_batch_val.isel(batch=i), y_batch_val.isel(batch=i), exclude=("b",) + ) + expected_x.append(ex_x) + expected_y.append(ex_y) + + expected_x = xr_concat(expected_x, dim="batch") + expected_y = xr_concat(expected_y, dim="batch") + + xr_assert_allclose(res_x, expected_x) + xr_assert_allclose(res_y, expected_y) + + +def test_expand_dims_batch_length_vectorize(): + x = xtensor("x", dims=("a",), shape=(3,)) + l = scalar("l", dtype="int64") + y = x.expand_dims(b=l) + + x_batch = as_xtensor(xr_random_like(x).expand_dims(batch=2)).type("x_batch") + l_batch = vector("l_batch", dtype="int64") + + with pytest.raises( + NotImplementedError, match=r"Vectorization of .* not implemented" + ): + 
vectorize_graph([y], {x: x_batch, l: l_batch}) + + +def test_unstack_batch_length_vectorize(): + x = xtensor("x", dims=("ab",), shape=(12,)) + l = scalar("l", dtype="int64") + y = unstack(x, ab={"a": l, "b": x.sizes["ab"] // l}) + + x_batch = as_xtensor(xr_random_like(x).expand_dims(batch=2)).type("x_batch") + l_batch = vector("l_batch", dtype="int64") + + with pytest.raises( + NotImplementedError, match=r"Vectorization of .* not implemented" + ): + vectorize_graph([y], {x: x_batch, l: l_batch}) diff --git a/tests/xtensor/test_signal.py b/tests/xtensor/test_signal.py index 5984397d6f..ef35fcae00 100644 --- a/tests/xtensor/test_signal.py +++ b/tests/xtensor/test_signal.py @@ -12,7 +12,12 @@ from pytensor.xtensor.signal import convolve1d from pytensor.xtensor.type import xtensor -from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function +from tests.xtensor.util import ( + check_vectorization, + xr_arange_like, + xr_assert_allclose, + xr_function, +) @pytest.mark.parametrize("mode", ("full", "valid", "same")) @@ -68,3 +73,14 @@ def test_convolve_1d_invalid(): match=re.escape("Input 1 has invalid core dims ['time']. Allowed: ('kernel',)"), ): convolve1d(in1, in2.rename({"batch": "time"}), dims=("time", "kernel")) + + +def test_signal_vectorize(): + # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific + ab = xtensor("a", dims=("a", "b"), shape=(3, 3)) + c = xtensor(name="c", dims=("c",), shape=(7,)) + + check_vectorization( + [ab, c], + [convolve1d(ab, c, dims=("a", "c"))], + ) diff --git a/tests/xtensor/test_vectorization.py b/tests/xtensor/test_vectorization.py new file mode 100644 index 0000000000..2042619f10 --- /dev/null +++ b/tests/xtensor/test_vectorization.py @@ -0,0 +1,114 @@ +import numpy as np + +from pytensor.tensor.type import tensor +from pytensor.xtensor.basic import xtensor_from_tensor +from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor +from pytensor.xtensor.vectorization import vectorize_graph +from tests.unittest_tools import assert_equal_computations + + +class TestVectorizeGraph: + def test_pure_xtensor_graph(self): + x = xtensor("x", dims=("a",)) + out = x + 1 + + x_new = xtensor("x_new", dims=("c", "a", "b")) + [out_vec] = vectorize_graph([out], {x: x_new}) + + assert isinstance(out_vec.type, XTensorType) + assert out_vec.type.dims == ("c", "b", "a") + expected = x_new.transpose("c", "b", "a") + 1 + assert_equal_computations([out_vec], [expected]) + + def test_intermediate_tensor_graph(self): + x = xtensor("x", dims=("a",)) + t = x.values # Convert to TensorVariable + t2 = t + np.ones(1) + out = xtensor_from_tensor(t2, dims=("a",)) + + x_new = xtensor("x_new", dims=("a", "b")) + [out_vec] = vectorize_graph([out], {x: x_new}) + + assert isinstance(out_vec.type, XTensorType) + assert out_vec.type.dims == ("b", "a") + expected = as_xtensor( + x_new.transpose("b", "a").values + np.ones(1), dims=("b", "a") + ) + assert_equal_computations([out_vec], [expected]) + + def test_intermediate_tensor_multiple_inputs_graph(self): + x = xtensor("x", dims=("a",)) + y = xtensor("y", dims=("a",)) + t = x.values + y.values + out = xtensor_from_tensor(t, dims=("a",)) + + x_new = xtensor("x_new", dims=("a", "c")) + + # Both inputs have the same batch dims + y_new = xtensor("y_new", dims=("c", "a")) + [out_vec] = vectorize_graph([out], {x: x_new, y: y_new}) + + assert isinstance(out_vec.type, XTensorType) + assert out_vec.type.dims == ("c", "a") + expected = as_xtensor( + (x_new.transpose("c", "a").values + 
y_new.transpose("c", "a").values), + dims=("c", "a"), + ) + assert_equal_computations([out_vec], [expected]) + + # Inputs have different batch dims + y_new = xtensor("y_new", dims=("b", "a")) + [out_vec] = vectorize_graph([out], {x: x_new, y: y_new}) + + assert isinstance(out_vec.type, XTensorType) + assert out_vec.type.dims == ("c", "b", "a") + expected = as_xtensor( + ( + x_new.transpose("c", "a").values[:, None] + + y_new.transpose("b", "a").values[None, :] + ), + dims=("c", "b", "a"), + ) + assert_equal_computations([out_vec], [expected]) + + def test_mixed_type_inputs(self): + x = xtensor("x", dims=("a",), shape=(3,)) + y = tensor("y", shape=(5,)) + + out = as_xtensor(y[2:], dims=("b",)) + x + + x_new = xtensor("x_new", dims=("a", "d"), shape=(3, 7)) + y_new = tensor("y_new", shape=(7, 5)) + + # Case where the new dimension of y is aligned with the new dimension of x + [out_vec] = vectorize_graph([out], {x: x_new, y: y_new}, new_tensor_dims=["d"]) + assert isinstance(out_vec.type, XTensorType) + assert out_vec.type.dims == ("d", "b", "a") + expected = as_xtensor(y_new[:, 2:], dims=("d", "b")) + x_new.transpose("d", "a") + assert_equal_computations([out_vec], [expected]) + return + + # New dimension of y is distinct from that of x + [out_vec] = vectorize_graph([out], {x: x_new, y: y_new}, new_tensor_dims=["c"]) + assert isinstance(out_vec.type, XTensorType) + assert out_vec.type.dims == ("d", "c", "b", "a") + return + expected = as_xtensor(y_new, dims=("c", "b")) + x_new.transpose("d", "a") + assert_equal_computations([out_vec], [expected]) + + def test_pure_tensor_graph(self): + # TODO + pass + + def test_intermediate_xtensor_graph(self): + # TODO: Inputs and outputs all tensor, intermediate xtensor graph + pass + + def test_invalid_cases(self): + """TODO + - missing xtensor dims + - new xtensor dims that were present in original graph + - missing tensor dims + - missing new_tensor_dims + - unused new_tensor_dims + """ diff --git a/tests/xtensor/util.py b/tests/xtensor/util.py index 82d1f071c8..19ba087699 100644 --- a/tests/xtensor/util.py +++ b/tests/xtensor/util.py @@ -1,14 +1,18 @@ +from itertools import chain + import pytest +from pytensor.graph import vectorize_graph + -pytest.importorskip("xarray") +xr = pytest.importorskip("xarray") import numpy as np from xarray import DataArray from xarray.testing import assert_allclose from pytensor import function -from pytensor.xtensor.type import XTensorType +from pytensor.xtensor.type import XTensorType, as_xtensor def xr_function(*args, **kwargs): @@ -76,3 +80,57 @@ def xr_random_like(x, rng=None): return DataArray( rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims ) + + +def check_vectorization(inputs, outputs, input_vals=None, rng=None): + # Create core graph and function + if not isinstance(inputs, list | tuple): + inputs = (inputs,) + + if not isinstance(outputs, list | tuple): + outputs = (outputs,) + + # apply_ufunc isn't happy with list output or single entry + _core_fn = function(inputs, outputs) + + def core_fn(*args, _core_fn=_core_fn): + res = _core_fn(*args) + if len(res) == 1: + return res[0] + else: + return tuple(res) + + if input_vals is None: + rng = np.random.default_rng(rng) + input_vals = [xr_random_like(inp, rng) for inp in inputs] + + # Create vectorized inputs + batch_inputs = [] + batch_input_vals = [] + for i, (inp, val) in enumerate(zip(inputs, input_vals)): + new_val = val.expand_dims({f"batch_{i}": 2 ** (i + 1)}) + new_inp = as_xtensor(new_val).type("batch_inp") + 
batch_inputs.append(new_inp) + batch_input_vals.append(new_val) + + # Create vectorized function + new_outputs = vectorize_graph(outputs, dict(zip(inputs, batch_inputs))) + vec_fn = xr_function(batch_inputs, new_outputs) + vec_res = vec_fn(*batch_input_vals) + + # xarray.apply_ufunc with vectorize=True loops over non-core dims + input_core_dims = [i.dims for i in inputs] + output_core_dims = [o.dims for o in outputs] + expected_res = xr.apply_ufunc( + core_fn, + *batch_input_vals, + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + exclude_dims=set(chain.from_iterable((*input_core_dims, *output_core_dims))), + vectorize=True, + ) + if not isinstance(expected_res, list | tuple): + expected_res = (expected_res,) + + for v_r, e_r in zip(vec_res, expected_res): + xr_assert_allclose(v_r, e_r.transpose(*v_r.dims))
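A minimal usage sketch of the dimension-aware vectorize_graph introduced in pytensor/xtensor/vectorization.py, mirroring the behavior exercised by tests/xtensor/test_vectorization.py; variable names here are illustrative only, not part of the patch:

# Assumed imports from the API added by this PR
from pytensor.xtensor.type import xtensor
from pytensor.xtensor.vectorization import vectorize_graph

# Core graph over a single "a" dimension
x = xtensor("x", dims=("a",))
out = x + 1

# The replacement introduces a new "batch" dimension; per the tests above,
# new dimensions are aligned on the left of the pre-existing core dims.
x_batched = xtensor("x_batched", dims=("batch", "a"))
[out_batched] = vectorize_graph([out], {x: x_batched})
assert out_batched.type.dims == ("batch", "a")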