Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import typing
import warnings
from types import EllipsisType

Expand All @@ -11,6 +10,7 @@
)
from pytensor.scalar import ScalarType
from pytensor.tensor import (
TensorLike,
TensorType,
_as_tensor_variable,
as_tensor_variable,
Expand All @@ -20,14 +20,16 @@


try:
import xarray as xr
import xarray

DataArray = xarray.DataArray
XARRAY_AVAILABLE = True
except ModuleNotFoundError:
XARRAY_AVAILABLE = False

from collections.abc import Sequence
from typing import Any, Literal, TypeVar
from typing import Any, Literal, TypeAlias, TypeVar, Union
from typing import cast as typing_cast

import numpy as np

Expand Down Expand Up @@ -95,7 +97,7 @@ def clone(

def filter(self, value, strict=False, allow_downcast=None):
# XTensorType behaves like TensorType at runtime, so we filter the same way.
if XARRAY_AVAILABLE and isinstance(value, xr.DataArray):
if XARRAY_AVAILABLE and isinstance(value, DataArray):
value = value.transpose(*self.dims).values
return TensorType.filter(
self, value, strict=strict, allow_downcast=allow_downcast
Expand All @@ -109,7 +111,7 @@ def filter_variable(self, other, allow_convert=True):
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
if XARRAY_AVAILABLE and isinstance(other, xr.DataArray):
if XARRAY_AVAILABLE and isinstance(other, DataArray):
other = other.transpose(*self.dims).values
other = XTensorConstant(type=self, data=other)

Expand Down Expand Up @@ -369,7 +371,7 @@ def __trunc__(self):
@property
def values(self) -> TensorVariable:
"""Convert to a TensorVariable with the same data."""
return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self))
return typing_cast(TensorVariable, px.basic.tensor_from_xtensor(self))

# Can't provide property data because that's already taken by Constants!
# data = values
Expand Down Expand Up @@ -409,7 +411,7 @@ def shape(self) -> tuple[TensorVariable, ...]:
@property
def size(self) -> TensorVariable:
"""The total number of elements in the variable."""
return typing.cast(TensorVariable, variadic_mul(*self.shape))
return typing_cast(TensorVariable, variadic_mul(*self.shape))

@property
def dtype(self) -> str:
Expand Down Expand Up @@ -904,6 +906,12 @@ def broadcast_like(self, other, exclude=None):
return self_bcast


if XARRAY_AVAILABLE:
XTensorLike: TypeAlias = Union[TensorLike, XTensorVariable, "DataArray"]
else:
XTensorLike: TypeAlias = TensorLike | XTensorVariable
Comment on lines +909 to +912
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about this conditional type hint, but I wanted to keep xarray optional



class XTensorConstantSignature(TensorConstantSignature):
pass

Expand Down Expand Up @@ -939,18 +947,18 @@ def _extract_data_and_dims(
x, dims: None | Sequence[str] = None
) -> tuple[np.ndarray, tuple[str, ...]]:
x_dims: tuple[str, ...]
if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
if XARRAY_AVAILABLE and isinstance(x, DataArray):
xarray_dims = x.dims
if not all(isinstance(dim, str) for dim in xarray_dims):
raise NotImplementedError(
"DataArray can only be converted to xtensor if all dims are of string type"
)
x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims))
x_dims = tuple(typing_cast(Sequence[str], xarray_dims))
x_data = x.values

if dims is not None and dims != x_dims:
raise ValueError(
f"xr.DataArray dims {x_dims} don't match requested specified {dims}. "
f"xarray.DataArray dims {x_dims} don't match requested specified {dims}. "
"Use transpose or rename"
)
else:
Expand Down Expand Up @@ -1015,8 +1023,8 @@ def xtensor_shared(


if XARRAY_AVAILABLE:
_as_symbolic.register(xr.DataArray, xtensor_constant)
shared_constructor.register(xr.DataArray, xtensor_shared)
_as_symbolic.register(DataArray, xtensor_constant)
shared_constructor.register(DataArray, xtensor_shared)


def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None):
Expand Down
Loading