diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 98b541083a..7fcb18f1f2 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -1,4 +1,3 @@ -import typing import warnings from types import EllipsisType @@ -11,6 +10,7 @@ ) from pytensor.scalar import ScalarType from pytensor.tensor import ( + TensorLike, TensorType, _as_tensor_variable, as_tensor_variable, @@ -20,14 +20,16 @@ try: - import xarray as xr + import xarray + DataArray = xarray.DataArray XARRAY_AVAILABLE = True except ModuleNotFoundError: XARRAY_AVAILABLE = False from collections.abc import Sequence -from typing import Any, Literal, TypeVar +from typing import Any, Literal, TypeAlias, TypeVar, Union +from typing import cast as typing_cast import numpy as np @@ -95,7 +97,7 @@ def clone( def filter(self, value, strict=False, allow_downcast=None): # XTensorType behaves like TensorType at runtime, so we filter the same way. - if XARRAY_AVAILABLE and isinstance(value, xr.DataArray): + if XARRAY_AVAILABLE and isinstance(value, DataArray): value = value.transpose(*self.dims).values return TensorType.filter( self, value, strict=strict, allow_downcast=allow_downcast @@ -109,7 +111,7 @@ def filter_variable(self, other, allow_convert=True): if not isinstance(other, Variable): # The value is not a Variable: we cast it into # a Constant of the appropriate Type. - if XARRAY_AVAILABLE and isinstance(other, xr.DataArray): + if XARRAY_AVAILABLE and isinstance(other, DataArray): other = other.transpose(*self.dims).values other = XTensorConstant(type=self, data=other) @@ -369,7 +371,7 @@ def __trunc__(self): @property def values(self) -> TensorVariable: """Convert to a TensorVariable with the same data.""" - return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self)) + return typing_cast(TensorVariable, px.basic.tensor_from_xtensor(self)) # Can't provide property data because that's already taken by Constants! # data = values @@ -409,7 +411,7 @@ def shape(self) -> tuple[TensorVariable, ...]: @property def size(self) -> TensorVariable: """The total number of elements in the variable.""" - return typing.cast(TensorVariable, variadic_mul(*self.shape)) + return typing_cast(TensorVariable, variadic_mul(*self.shape)) @property def dtype(self) -> str: @@ -904,6 +906,12 @@ def broadcast_like(self, other, exclude=None): return self_bcast +if XARRAY_AVAILABLE: + XTensorLike: TypeAlias = Union[TensorLike, XTensorVariable, "DataArray"] +else: + XTensorLike: TypeAlias = TensorLike | XTensorVariable + + class XTensorConstantSignature(TensorConstantSignature): pass @@ -939,18 +947,18 @@ def _extract_data_and_dims( x, dims: None | Sequence[str] = None ) -> tuple[np.ndarray, tuple[str, ...]]: x_dims: tuple[str, ...] - if XARRAY_AVAILABLE and isinstance(x, xr.DataArray): + if XARRAY_AVAILABLE and isinstance(x, DataArray): xarray_dims = x.dims if not all(isinstance(dim, str) for dim in xarray_dims): raise NotImplementedError( "DataArray can only be converted to xtensor if all dims are of string type" ) - x_dims = tuple(typing.cast(typing.Iterable[str], xarray_dims)) + x_dims = tuple(typing_cast(Sequence[str], xarray_dims)) x_data = x.values if dims is not None and dims != x_dims: raise ValueError( - f"xr.DataArray dims {x_dims} don't match requested specified {dims}. " + f"xarray.DataArray dims {x_dims} don't match requested specified {dims}. " "Use transpose or rename" ) else: @@ -1015,8 +1023,8 @@ def xtensor_shared( if XARRAY_AVAILABLE: - _as_symbolic.register(xr.DataArray, xtensor_constant) - shared_constructor.register(xr.DataArray, xtensor_shared) + _as_symbolic.register(DataArray, xtensor_constant) + shared_constructor.register(DataArray, xtensor_shared) def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None):