2 changes: 1 addition & 1 deletion pyproject.toml
@@ -163,7 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py" = ["E402"]
"tests/link/pytorch/**/test_*.py" = ["E402"]
"tests/link/mlx/**/test_*.py" = ["E402"]
"tests/xtensor/**/test_*.py" = ["E402"]
"tests/xtensor/**/*.py" = ["E402"]



32 changes: 22 additions & 10 deletions pytensor/graph/replace.py
@@ -208,19 +208,13 @@ def toposort_key(


@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]:
# Default implementation is provided in pytensor.tensor.blockwise
raise NotImplementedError


def vectorize_node(node: Apply, *batched_inputs) -> Apply:
"""Returns vectorized version of node with new batched inputs."""
op = node.op
return _vectorize_node(op, node, *batched_inputs)


def _vectorize_not_needed(op, node, *batched_inputs):
return op.make_node(*batched_inputs)
return op.make_node(*batched_inputs).outputs


@overload
@@ -289,19 +283,37 @@ def vectorize_graph(
# [array([-10., -11.]), array([10., 11.])]

"""
# TODO: Move this to tensor.vectorize, and make this helper type agnostic.
#
# This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types
# The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels
#
# xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't,
# as it is by design unaware of xtensors and their semantics.
if isinstance(outputs, Sequence):
seq_outputs = outputs
else:
seq_outputs = [outputs]

if not all(
isinstance(key, Variable) and isinstance(value, Variable)
for key, value in replace.items()
):
raise ValueError(f"Some of the replaced items are not Variables: {replace}")

inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
new_inputs = [replace.get(inp, inp) for inp in inputs]

vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
vect_node_or_outputs = _vectorize_node(node.op, node, *vect_inputs)
if isinstance(vect_node_or_outputs, Apply):
# Compatibility with the old API
vect_outputs = vect_node_or_outputs.outputs
else:
vect_outputs = vect_node_or_outputs
for output, vect_output in zip(node.outputs, vect_outputs, strict=True):
if output in vect_vars:
# This can happen when some outputs of a multi-output node are given a replacement,
# while some of the remaining outputs are still needed in the graph.
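Note (not part of the diff): under the updated contract above, a dispatch registered on `_vectorize_node` may return either an `Apply` (old API) or a sequence of output `Variable`s, and `vectorize_graph` normalizes both through the compatibility branch. A minimal, untested sketch with a hypothetical `MyOp`:

from pytensor.graph import Op
from pytensor.graph.replace import _vectorize_node


class MyOp(Op):
    """Hypothetical user-defined Op; stands in for any Op with a make_node."""


@_vectorize_node.register(MyOp)
def _vectorize_my_op(op, node, *batched_inputs):
    # New-style return value: the output Variables themselves.
    # Returning op.make_node(*batched_inputs) (an Apply) would also still work.
    return op.make_node(*batched_inputs).outputs
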
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/blockwise.py
@@ -1,7 +1,7 @@
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
@@ -38,7 +38,7 @@ def local_useless_blockwise(fgraph, node):
op = node.op
inputs = node.inputs
dummy_core_node = op._create_dummy_core_node(node.inputs)
vect_node = vectorize_node(dummy_core_node, *inputs)
vect_node = _vectorize_node(dummy_core_node.op, dummy_core_node, *inputs)
if not isinstance(vect_node.op, Blockwise):
return copy_stack_trace(node.outputs, vect_node.outputs)

2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/subtensor.py
@@ -1602,7 +1602,7 @@ def local_blockwise_of_subtensor(fgraph, node):
def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwised inc_subtensors.

Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
Note: The reason we don't apply this rewrite eagerly in the `_vectorize_node` dispatch
is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites

such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
34 changes: 34 additions & 0 deletions pytensor/xtensor/basic.py
@@ -17,6 +17,9 @@ def perform(self, node, inputs, outputs):
def do_constant_folding(self, fgraph, node):
return False

def vectorize_node(self, node, *new_inputs, new_dim: str | None):
raise NotImplementedError(f"Vectorized node not implemented for {self}")


class XTypeCastOp(TypeCastingOp):
"""Base class for Ops that type cast between TensorType and XTensorType.
@@ -27,6 +30,9 @@ class XTypeCastOp(TypeCastingOp):
def infer_shape(self, fgraph, node, input_shapes):
return input_shapes

def vectorize_node(self, node, *new_inputs, new_dim: str | None):
raise NotImplementedError(f"Vectorized node not implemented for {self}")


class TensorFromXTensor(XTypeCastOp):
__props__ = ()
@@ -42,6 +48,12 @@ def L_op(self, inputs, outs, g_outs):
[g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)]

def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs
# We transpose batch dims to the left, for consistency with tensor vectorization
new_x = new_x.transpose(..., *old_x.dims)
return [self(new_x)]


tensor_from_xtensor = TensorFromXTensor()

@@ -63,6 +75,18 @@ def L_op(self, inputs, outs, g_outs):
[g_out] = g_outs
return [tensor_from_xtensor(g_out)]

def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs
if new_x.ndim != old_x.ndim:
if new_dim is None:
raise NotImplementedError(
f"Vectorization of {self} is not well defined because it can't infer the new dimension labels. "
f"Use pytensor.xtensor.vectorization.vectorize_graph instead."
)
return [type(self)(dims=(new_dim, *self.dims))(new_x)]
else:
return [self(new_x)]


def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name)
@@ -85,6 +109,16 @@ def L_op(self, inputs, outs, g_outs):
[g_out] = g_outs
return [rename(g_out, dims=x.type.dims)]

def vectorize_node(self, node, new_x, new_dim):
[old_x] = node.inputs
old_dim_mapping = dict(zip(old_x.dims, self.new_dims, strict=True))

# new_dims may include a mix of old dims (possibly re-ordered), and new dims which won't be renamed
new_dims = tuple(
old_dim_mapping.get(dim, dim) for dim in new_x.dims
)
return [type(self)(new_dims)(new_x)]


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
if name_dict is not None:
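Illustration (not from the diff, untested): how `XTensorFromTensor.vectorize_node` above is expected to behave when the batched replacement gains a leading dimension, using made-up dim names. The extra dimension must be given a label (`new_dim`), which is prepended to the op's existing dims:

import pytensor.tensor as pt
from pytensor.xtensor.basic import xtensor_from_tensor

x = pt.matrix("x")                         # core 2d tensor
xx = xtensor_from_tensor(x, dims=("a", "b"))
node = xx.owner                            # the XTensorFromTensor Apply

x_batched = pt.tensor3("x_batched")        # same core dims plus one batch dim
[vect_out] = node.op.vectorize_node(node, x_batched, new_dim="batch")
# vect_out is expected to have dims ("batch", "a", "b"); with new_dim=None this
# raises NotImplementedError, pointing to the label-aware vectorize_graph instead.
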
27 changes: 27 additions & 0 deletions pytensor/xtensor/indexing.py
@@ -4,13 +4,15 @@
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
from itertools import chain
from typing import Literal

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.shape import broadcast
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


@@ -195,6 +197,15 @@ def combine_dim_info(idx_dim, idx_dim_shape):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output])

def vectorize_node(self, node, new_x, *new_idxs, new_dim):
# new_x may have dims in different order
# we pair each pre-existing dim to the respective index
# with new dims having simply a slice(None)
old_x, *_ = node.inputs
dims_to_idxs = dict(zip(old_x.dims, new_idxs, strict=True))
new_idxs = tuple(dims_to_idxs.get(dim, slice(None)) for dim in new_x.dims)
return [self(new_x, *new_idxs)]


index = Index()

@@ -226,6 +237,22 @@ def make_node(self, x, y, *idxs):
out = x.type()
return Apply(self, [x, y, *idxs], [out])

def vectorize_node(self, node, *new_inputs):
# If y or the indices have new dimensions, we need to broadcast x
exclude: set[str] = set(
chain.from_iterable(old_inp.dims for old_inp in node.inputs)
)
for new_inp, old_inp in zip(new_inputs, node.inputs, strict=True):
# Note: This check may be too conservative
if invalid_new_dims := ((set(new_inp.dims) - set(old_inp.dims)) & exclude):
raise NotImplementedError(
f"Vectorize of {self} is undefined because one of the inputs {new_inp} new dimensions "
f"was present in the old inputs: {sorted(invalid_new_dims)}"
)
new_x, *_ = broadcast(*new_inputs, exclude=tuple(exclude))
_, new_y, *new_idxs = new_inputs
return self.make_node(new_x, new_y, *new_idxs)


index_assignment = IndexUpdate("set")
index_increment = IndexUpdate("inc")
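A tiny standalone illustration (placeholder names, not from the diff) of the dim-to-index pairing that `Index.vectorize_node` above performs: indices stay attached to the dims they originally indexed, and any new (batch) dims of the replacement get a full slice:

old_dims = ("a", "b")                      # dims of the originally indexed variable
new_dims = ("batch", "b", "a")             # dims of the batched replacement
new_idxs = ("idx_a", "idx_b")              # placeholder indices, paired with old_dims
dims_to_idxs = dict(zip(old_dims, new_idxs, strict=True))
result = tuple(dims_to_idxs.get(dim, slice(None)) for dim in new_dims)
assert result == (slice(None), "idx_b", "idx_a")
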
6 changes: 6 additions & 0 deletions pytensor/xtensor/reduction.py
@@ -46,6 +46,9 @@ def make_node(self, x):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output])

def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)]


def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]:
if isinstance(dim, str):
@@ -117,6 +120,9 @@ def make_node(self, x):
out = x.type()
return Apply(self, [x], [out])

def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)]


def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
x = as_xtensor(x)
43 changes: 43 additions & 0 deletions pytensor/xtensor/shape.py
@@ -68,6 +68,9 @@ def make_node(self, x):
)
return Apply(self, [x], [output])

def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)]


def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
if dim is not None:
@@ -146,6 +149,14 @@ def make_node(self, x, *unstacked_length):
)
return Apply(self, [x, *unstacked_lengths], [output])

def vectorize_node(self, node, new_x, *new_unstacked_length, new_dim):
new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
raise NotImplementedError(
f"Vectorization of {self} with batched unstacked_length not implemented, "
)
return [self(new_x, *new_unstacked_length)]


def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
if dim is not None:
@@ -189,6 +200,11 @@ def make_node(self, x):
)
return Apply(self, [x], [output])

def vectorize_node(self, node, new_x, new_dim):
old_dims = self.dims
new_dims = tuple(dim for dim in new_x.dims if dim not in old_dims)
return [type(self)(dims=(*new_dims, *old_dims))(new_x)]


def transpose(
x,
@@ -302,6 +318,9 @@ def make_node(self, *inputs):
output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output])

def vectorize_node(self, node, *new_inputs, new_dim):
return [self(*new_inputs)]


def concat(xtensors, dim: str):
"""Concatenate a sequence of XTensorVariables along a specified dimension.
@@ -383,6 +402,9 @@ def make_node(self, x):
)
return Apply(self, [x], [out])

def vectorize_node(self, node, new_x, new_dim):
return [self(new_x)]


def squeeze(x, dim: str | Sequence[str] | None = None):
"""Remove dimensions of size 1 from an XTensorVariable."""
@@ -442,6 +464,14 @@ def make_node(self, x, size):
)
return Apply(self, [x, size], [out])

def vectorize_node(self, node, new_x, new_size, new_dim):
new_size = new_size.squeeze()
if new_size.type.ndim != 0:
raise NotImplementedError(
f"Vectorization of {self} with batched new_size not implemented, "
)
return self.make_node(new_x, new_size)


def expand_dims(x, dim=None, axis=None, **dim_kwargs):
"""Add one or more new dimensions to an XTensorVariable."""
@@ -537,6 +567,19 @@ def make_node(self, *inputs):

return Apply(self, inputs, outputs)

def vectorize_node(self, node, *new_inputs, new_dim):
if exclude_set := set(self.exclude):
for new_x, old_x in zip(new_inputs, node.inputs, strict=True):
if invalid_excluded := (
(set(new_x.dims) - set(old_x.dims)) & exclude_set
):
raise NotImplementedError(
f"Vectorize of {self} is undefined because one of the inputs {new_x} "
f"has an excluded dimension {sorted(invalid_excluded)} that it did not have before."
)

return self(*new_inputs, return_list=True)


def broadcast(
*args, exclude: str | Sequence[str] | None = None
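Sketch (untested, assumes the xarray-style `transpose` helper defined in this file accepts dimension names positionally and builds the transpose op shown above): batch dims that the original transpose did not know about are expected to be kept on the left:

from pytensor.xtensor.shape import transpose
from pytensor.xtensor.type import xtensor

x = xtensor(dtype="float64", shape=(2, 3), dims=("a", "b"))
xt = transpose(x, "b", "a")                # op holds dims ("b", "a")

x_batched = xtensor(dtype="float64", shape=(4, 2, 3), dims=("batch", "a", "b"))
[vect_out] = xt.owner.op.vectorize_node(xt.owner, x_batched, new_dim=None)
# vect_out is expected to have dims ("batch", "b", "a")
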
2 changes: 1 addition & 1 deletion pytensor/xtensor/type.py
@@ -1044,7 +1044,7 @@ def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None)

if isinstance(x, Variable):
if isinstance(x.type, XTensorType):
if (dims is None) or (x.type.dims == dims):
if (dims is None) or (x.type.dims == tuple(dims)):
return x
else:
raise ValueError(
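Quick illustration of why the `tuple(dims)` coercion matters: callers commonly pass `dims` as a list, which previously failed the equality check against the stored tuple even when the labels matched. A small, untested sketch:

from pytensor.xtensor.type import as_xtensor, xtensor

x = xtensor(dtype="float64", shape=(2, 3), dims=("a", "b"))
# x.type.dims is the tuple ("a", "b"); a list argument now compares equal after
# coercion, so the same variable is returned instead of raising.
assert as_xtensor(x, dims=["a", "b"]) is x
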