Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pytensor.compile import (
DeepCopyOp,
SharedVariable,
ViewOp,
register_deep_copy_op_c_code,
register_view_op_c_code,
Expand Down Expand Up @@ -32,6 +33,7 @@

import pytensor.xtensor as px
from pytensor import _as_symbolic, config
from pytensor.compile.sharedvalue import shared_constructor
from pytensor.graph import Apply, Constant
from pytensor.graph.basic import OptionalApplyType, Variable
from pytensor.graph.type import HasDataType, HasShape, Type
Expand Down Expand Up @@ -93,6 +95,8 @@ def clone(

def filter(self, value, strict=False, allow_downcast=None):
    """Coerce ``value`` to a runtime value compatible with this type.

    ``xarray.DataArray`` inputs are first transposed to this type's
    ``dims`` and reduced to their underlying ndarray; everything else is
    validated exactly as ``TensorType`` would validate it.
    """
    # XTensorType behaves like TensorType at runtime, so we filter the same way.
    if XARRAY_AVAILABLE and isinstance(value, xr.DataArray):
        # Align the DataArray's dim order with this type before dropping labels.
        value = value.transpose(*self.dims).values
    return TensorType.filter(
        self, value, strict=strict, allow_downcast=allow_downcast
    )
Expand All @@ -105,6 +109,8 @@ def filter_variable(self, other, allow_convert=True):
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
if XARRAY_AVAILABLE and isinstance(other, xr.DataArray):
other = other.transpose(*self.dims).values
other = XTensorConstant(type=self, data=other)

if self.is_super(other.type):
Expand Down Expand Up @@ -929,15 +935,15 @@ def signature(self):
XTensorType.constant_type = XTensorConstant # type: ignore


def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
"""Convert a constant value to an XTensorConstant."""

def _extract_data_and_dims(
x, dims: None | Sequence[str] = None
) -> tuple[np.ndarray, tuple[str, ...]]:
x_dims: tuple[str, ...]
if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
xarray_dims = x.dims
if not all(isinstance(dim, str) for dim in xarray_dims):
raise NotImplementedError(
"DataArray can only be converted to xtensor_constant if all dims are of string type"
"DataArray can only be converted to xtensor if all dims are of string type"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't auto-generate dim__0 dim__1 etc like arviz?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, I opted to be more picky so you don't accidentally end up broadcasting things / misaligning variables

)
x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims))
x_data = x.values
Expand All @@ -958,6 +964,13 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
raise TypeError(
"Cannot convert TensorLike constant to XTensorConstant without specifying dims."
)
return x_data, x_dims


def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
"""Convert a constant value to an XTensorConstant."""
x_data, x_dims = _extract_data_and_dims(x, dims)

try:
return XTensorConstant(
XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape),
Expand All @@ -968,11 +981,42 @@ def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
raise TypeError(f"Could not convert {x} to XTensorType")


if XARRAY_AVAILABLE:
class XTensorSharedVariable(SharedVariable, XTensorVariable):
    """Shared variable of XTensorType.

    Combines ``SharedVariable`` storage semantics with the
    ``XTensorVariable`` interface; no extra behavior is added here.
    """

@_as_symbolic.register(xr.DataArray)
def as_symbolic_xarray(x, **kwargs):
return xtensor_constant(x, **kwargs)

def xtensor_shared(
    x,
    *,
    name=None,
    shape=None,
    dims=None,
    strict=False,
    allow_downcast=None,
    borrow=False,
):
    r"""`SharedVariable` constructor for `XTensorType`\s.

    Notes
    -----
    By default every dimension of ``x`` is assumed resizable, i.e. the
    type's shape is ``(None,) * x.ndim``.  Pass an explicit `shape` to
    pin some (or all) dimensions.
    """
    # Normalize the input to a plain ndarray plus its dim names
    # (DataArrays carry their own dims; plain arrays require `dims`).
    data, data_dims = _extract_data_and_dims(x, dims)

    # Unless the caller explicitly borrows, keep our own copy of the buffer.
    stored_value = data if borrow else data.copy()

    # Fall back to the input's own name (e.g. DataArray.name) when none given.
    if name is None:
        name = getattr(x, "name", None)

    var_type = XTensorType(dtype=data.dtype, dims=data_dims, shape=shape)
    return XTensorSharedVariable(
        type=var_type,
        value=stored_value,
        strict=strict,
        allow_downcast=allow_downcast,
        name=name,
    )


# Hook xarray.DataArray into PyTensor's generic converters so that
# `as_symbolic(...)` and `shared(...)` transparently produce xtensor
# constants / shared variables when handed a DataArray.
if XARRAY_AVAILABLE:
    _as_symbolic.register(xr.DataArray, xtensor_constant)
    shared_constructor.register(xr.DataArray, xtensor_shared)


def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None):
Expand Down
97 changes: 80 additions & 17 deletions tests/xtensor/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,27 @@

import pytest

from pytensor import as_symbolic, shared
from pytensor.compile import SharedVariable


pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")

import numpy as np
from xarray import DataArray

from pytensor.graph.basic import equal_computations
from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import as_tensor, specify_shape, tensor
from pytensor.xtensor import xtensor
from pytensor.xtensor.type import XTensorConstant, XTensorType, as_xtensor
from pytensor.xtensor.type import (
XTensorConstant,
XTensorSharedVariable,
XTensorType,
as_xtensor,
xtensor_constant,
xtensor_shared,
)


def test_xtensortype():
Expand Down Expand Up @@ -110,23 +120,76 @@ def test_xtensortype_filter_variable_constant():
assert isinstance(res, XTensorConstant) and res.type == x.type


def test_xtensor_constant():
x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b")))
@pytest.mark.parametrize(
    "constant_constructor", (as_symbolic, as_xtensor, xtensor_constant)
)
def test_xtensor_constant(constant_constructor):
    """Every constant constructor accepts a DataArray and yields an XTensorConstant."""
    x = constant_constructor(DataArray(np.ones((2, 3)), dims=("a", "b")))
    assert isinstance(x, Constant)
    assert isinstance(x, XTensorConstant)
    assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))

    if constant_constructor is not as_symbolic:
        # We should be able to pass numpy arrays if we pass dims
        # (`as_symbolic` has no `dims` argument, so skip it here).
        y = as_xtensor(np.ones((2, 3)), dims=("a", "b"))
        assert y.type == x.type
        assert x.signature() == y.signature()
        assert x.equals(y)
        x_eval = x.eval()
        # Reuse the cached evaluation instead of calling eval() twice.
        assert isinstance(x_eval, np.ndarray)
        np.testing.assert_array_equal(x_eval, y.eval(), strict=True)

        # Same data with transposed dims is a *different* type/constant.
        z = as_xtensor(np.ones((3, 2)), dims=("b", "a"))
        assert z.type != x.type
        assert z.signature() != x.signature()
        assert not x.equals(z)
        np.testing.assert_array_equal(x_eval, z.eval().T, strict=True)


@pytest.mark.parametrize("shared_constructor", (shared, xtensor_shared))
def test_xtensor_shared(shared_constructor):
    """Shared variables are built from DataArrays (or arrays + dims) and round-trip values."""
    arr = np.array([[1, 2, 3], [4, 5, 6]], dtype="int64")
    xarr = DataArray(arr, dims=("a", "b"), name="xarr")
    shared_xarr = shared_constructor(xarr)
    assert isinstance(shared_xarr, SharedVariable)
    assert isinstance(shared_xarr, XTensorSharedVariable)
    assert shared_xarr.type == XTensorType(
        dtype="int64", dims=("a", "b"), shape=(None, None)
    )
    # Bug fix: assert on the *shared variable*'s name — the constructor is
    # expected to inherit it from the DataArray.  The original
    # `assert xarr.name == "xarr"` merely restated the input and tested nothing.
    assert shared_xarr.name == "xarr"

    # An explicit `shape` and `name` override the defaults.
    shared_rrax = shared_constructor(xarr, shape=(2, None), name="rrax")
    assert isinstance(shared_rrax, XTensorSharedVariable)
    assert shared_rrax.type == XTensorType(
        dtype="int64", dims=("a", "b"), shape=(2, None)
    )
    assert shared_rrax.name == "rrax"

    if shared_constructor is xtensor_shared:
        # We should be able to pass numpy arrays, if we pass dims
        with pytest.raises(TypeError):
            shared_constructor(arr)
        shared_arr = shared_constructor(arr, dims=("a", "b"))
        assert isinstance(shared_arr, XTensorSharedVariable)
        assert shared_arr.type == shared_xarr.type

    # Test get and set_value
    retrieved_value = shared_xarr.get_value()
    assert isinstance(retrieved_value, np.ndarray)
    np.testing.assert_allclose(retrieved_value, xarr.to_numpy())

    shared_xarr.set_value(xarr[::-1])
    np.testing.assert_allclose(shared_xarr.get_value(), xarr[::-1].to_numpy())

    # Test dims in different order (filter transposes back to the type's dims)
    shared_xarr.set_value(xarr[::-1].T)
    np.testing.assert_allclose(shared_xarr.get_value(), xarr[::-1].to_numpy())

    # Mismatched dim names must be rejected.
    with pytest.raises(ValueError):
        shared_xarr.set_value(xarr.rename(b="c"))

    # Plain ndarrays of the right shape are accepted as-is.
    shared_xarr.set_value(arr[::-1])
    np.testing.assert_allclose(shared_xarr.get_value(), arr[::-1])


def test_as_tensor():
Expand Down
Loading