From 65869684c5d6a1fca951823a68dc221c9ae371cd Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 8 Feb 2026 02:01:16 +0100 Subject: [PATCH] Implement shared XTensorVariables --- pytensor/xtensor/type.py | 60 +++++++++++++++++++---- tests/xtensor/test_type.py | 97 +++++++++++++++++++++++++++++++------- 2 files changed, 132 insertions(+), 25 deletions(-) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 18672ac83d..98b541083a 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -4,6 +4,7 @@ from pytensor.compile import ( DeepCopyOp, + SharedVariable, ViewOp, register_deep_copy_op_c_code, register_view_op_c_code, @@ -32,6 +33,7 @@ import pytensor.xtensor as px from pytensor import _as_symbolic, config +from pytensor.compile.sharedvalue import shared_constructor from pytensor.graph import Apply, Constant from pytensor.graph.basic import OptionalApplyType, Variable from pytensor.graph.type import HasDataType, HasShape, Type @@ -93,6 +95,8 @@ def clone( def filter(self, value, strict=False, allow_downcast=None): # XTensorType behaves like TensorType at runtime, so we filter the same way. + if XARRAY_AVAILABLE and isinstance(value, xr.DataArray): + value = value.transpose(*self.dims).values return TensorType.filter( self, value, strict=strict, allow_downcast=allow_downcast ) @@ -105,6 +109,8 @@ def filter_variable(self, other, allow_convert=True): if not isinstance(other, Variable): # The value is not a Variable: we cast it into # a Constant of the appropriate Type. + if XARRAY_AVAILABLE and isinstance(other, xr.DataArray): + other = other.transpose(*self.dims).values other = XTensorConstant(type=self, data=other) if self.is_super(other.type): @@ -929,15 +935,15 @@ def signature(self): XTensorType.constant_type = XTensorConstant # type: ignore -def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): - """Convert a constant value to an XTensorConstant.""" - +def _extract_data_and_dims( + x, dims: None | Sequence[str] = None +) -> tuple[np.ndarray, tuple[str, ...]]: x_dims: tuple[str, ...] if XARRAY_AVAILABLE and isinstance(x, xr.DataArray): xarray_dims = x.dims if not all(isinstance(dim, str) for dim in xarray_dims): raise NotImplementedError( - "DataArray can only be converted to xtensor_constant if all dims are of string type" + "DataArray can only be converted to xtensor if all dims are of string type" ) x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims)) x_data = x.values @@ -958,6 +964,13 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): raise TypeError( "Cannot convert TensorLike constant to XTensorConstant without specifying dims." ) + return x_data, x_dims + + +def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): + """Convert a constant value to an XTensorConstant.""" + x_data, x_dims = _extract_data_and_dims(x, dims) + try: return XTensorConstant( XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape), @@ -968,11 +981,42 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): raise TypeError(f"Could not convert {x} to XTensorType") -if XARRAY_AVAILABLE: +class XTensorSharedVariable(SharedVariable, XTensorVariable): + """Shared variable of XTensorType.""" - @_as_symbolic.register(xr.DataArray) - def as_symbolic_xarray(x, **kwargs): - return xtensor_constant(x, **kwargs) + +def xtensor_shared( + x, + *, + name=None, + shape=None, + dims=None, + strict=False, + allow_downcast=None, + borrow=False, +): + r"""`SharedVariable` constructor for `XTensorType`\s. + + Notes + ----- + The default is to assume that the `shape` value might be resized in any + dimension, so the default shape is ``(None,) * len(value.shape)``. The + optional `shape` argument will override this default. + """ + x_data, x_dims = _extract_data_and_dims(x, dims) + + return XTensorSharedVariable( + type=XTensorType(dtype=x_data.dtype, dims=x_dims, shape=shape), + value=x_data if borrow else x_data.copy(), + strict=strict, + allow_downcast=allow_downcast, + name=name if name is not None else getattr(x, "name", None), + ) + + +if XARRAY_AVAILABLE: + _as_symbolic.register(xr.DataArray, xtensor_constant) + shared_constructor.register(xr.DataArray, xtensor_shared) def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None): diff --git a/tests/xtensor/test_type.py b/tests/xtensor/test_type.py index fcb3b6bf5f..4414b30fcb 100644 --- a/tests/xtensor/test_type.py +++ b/tests/xtensor/test_type.py @@ -2,6 +2,9 @@ import pytest +from pytensor import as_symbolic, shared +from pytensor.compile import SharedVariable + pytest.importorskip("xarray") pytestmark = pytest.mark.filterwarnings("error") @@ -9,10 +12,17 @@ import numpy as np from xarray import DataArray -from pytensor.graph.basic import equal_computations +from pytensor.graph.basic import Constant, equal_computations from pytensor.tensor import as_tensor, specify_shape, tensor from pytensor.xtensor import xtensor -from pytensor.xtensor.type import XTensorConstant, XTensorType, as_xtensor +from pytensor.xtensor.type import ( + XTensorConstant, + XTensorSharedVariable, + XTensorType, + as_xtensor, + xtensor_constant, + xtensor_shared, +) def test_xtensortype(): @@ -110,23 +120,76 @@ def test_xtensortype_filter_variable_constant(): assert isinstance(res, XTensorConstant) and res.type == x.type -def test_xtensor_constant(): - x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b"))) +@pytest.mark.parametrize( + "constant_constructor", (as_symbolic, as_xtensor, xtensor_constant) +) +def test_xtensor_constant(constant_constructor): + x = constant_constructor(DataArray(np.ones((2, 3)), dims=("a", "b"))) + assert isinstance(x, Constant) + assert isinstance(x, XTensorConstant) assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3)) - y = as_xtensor(np.ones((2, 3)), dims=("a", "b")) - assert y.type == x.type - assert x.signature() == y.signature() - assert x.equals(y) - x_eval = x.eval() - assert isinstance(x.eval(), np.ndarray) - np.testing.assert_array_equal(x_eval, y.eval(), strict=True) - - z = as_xtensor(np.ones((3, 2)), dims=("b", "a")) - assert z.type != x.type - assert z.signature() != x.signature() - assert not x.equals(z) - np.testing.assert_array_equal(x_eval, z.eval().T, strict=True) + if constant_constructor is not as_symbolic: + # We should be able to pass numpy arrays if we pass dims + y = as_xtensor(np.ones((2, 3)), dims=("a", "b")) + assert y.type == x.type + assert x.signature() == y.signature() + assert x.equals(y) + x_eval = x.eval() + assert isinstance(x.eval(), np.ndarray) + np.testing.assert_array_equal(x_eval, y.eval(), strict=True) + + z = as_xtensor(np.ones((3, 2)), dims=("b", "a")) + assert z.type != x.type + assert z.signature() != x.signature() + assert not x.equals(z) + np.testing.assert_array_equal(x_eval, z.eval().T, strict=True) + + +@pytest.mark.parametrize("shared_constructor", (shared, xtensor_shared)) +def test_xtensor_shared(shared_constructor): + arr = np.array([[1, 2, 3], [4, 5, 6]], dtype="int64") + xarr = DataArray(arr, dims=("a", "b"), name="xarr") + shared_xarr = shared_constructor(xarr) + assert isinstance(shared_xarr, SharedVariable) + assert isinstance(shared_xarr, XTensorSharedVariable) + assert shared_xarr.type == XTensorType( + dtype="int64", dims=("a", "b"), shape=(None, None) + ) + assert xarr.name == "xarr" + + shared_rrax = shared_constructor(xarr, shape=(2, None), name="rrax") + assert isinstance(shared_rrax, XTensorSharedVariable) + assert shared_rrax.type == XTensorType( + dtype="int64", dims=("a", "b"), shape=(2, None) + ) + assert shared_rrax.name == "rrax" + + if shared_constructor == xtensor_shared: + # We should be able to pass numpy arrays, if we pass dims + with pytest.raises(TypeError): + shared_constructor(arr) + shared_arr = shared_constructor(arr, dims=("a", "b")) + assert isinstance(shared_arr, XTensorSharedVariable) + assert shared_arr.type == shared_xarr.type + + # Test get and set_value + retrieved_value = shared_xarr.get_value() + assert isinstance(retrieved_value, np.ndarray) + np.testing.assert_allclose(retrieved_value, xarr.to_numpy()) + + shared_xarr.set_value(xarr[::-1]) + np.testing.assert_allclose(shared_xarr.get_value(), xarr[::-1].to_numpy()) + + # Test dims in different order + shared_xarr.set_value(xarr[::-1].T) + np.testing.assert_allclose(shared_xarr.get_value(), xarr[::-1].to_numpy()) + + with pytest.raises(ValueError): + shared_xarr.set_value(xarr.rename(b="c")) + + shared_xarr.set_value(arr[::-1]) + np.testing.assert_allclose(shared_xarr.get_value(), arr[::-1]) def test_as_tensor():