diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index fed23e73364..745340822d1 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -1,3 +1,5 @@
+from typing import TypeVar
+
 import torch
 
 from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format
@@ -6,34 +8,30 @@
 from ._mask import Mask
 from ._torch_function_helpers import set_return_type
 from ._tv_tensor import TVTensor
+from torchvision.tv_tensors._tv_tensor import TVTensor
+
 from ._video import Video
 
 
+TVTensorType = TypeVar("TVTensorType", bound=TVTensor)
+
+
 # TODO: Fix this. We skip this method as it leads to
 # RecursionError: maximum recursion depth exceeded while calling a Python object
 # Until `disable` is removed, there will be graph breaks after all calls to functional transforms
 @torch.compiler.disable
-def wrap(wrappee, *, like, **kwargs):
+def wrap(wrappee: torch.Tensor, *, like: TVTensorType, **kwargs) -> TVTensorType:
     """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
 
-    If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
-    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.
-
     Args:
         wrappee (Tensor): The tensor to convert.
         like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
-            ``wrappee`` will be converted into the same subclass as ``like``.
-        kwargs: Can contain "format", "canvas_size" and "clamping_mode" if ``like`` is a :class:`~torchvision.tv_tensor.BoundingBoxes`.
-            Ignored otherwise.
+            ``wrappee`` will be converted into the same subclass as ``like``
+            maintaining the same metadata as ``like``.
+        kwargs: Optional overrides for metadata. For BoundingBoxes: ``format``, ``canvas_size``, ``clamping_mode``.
+            For KeyPoints: ``canvas_size``.
     """
-    if isinstance(like, BoundingBoxes):
-        return type(like)._wrap(
-            wrappee,
-            format=kwargs.get("format", like.format),
-            canvas_size=kwargs.get("canvas_size", like.canvas_size),
-            clamping_mode=kwargs.get("clamping_mode", like.clamping_mode),
-        )
-    elif isinstance(like, KeyPoints):
-        return type(like)._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size))
-    else:
-        return wrappee.as_subclass(type(like))
+    if not hasattr(like, "__wrap__"):
+        raise TypeError(f"Expected `like` to have a `__wrap__` method, but got {type(like)}")
+
+    return like.__wrap__(wrappee, **kwargs)
diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py
index 7aa3e50458d..2f67632f52d 100644
--- a/torchvision/tv_tensors/_bounding_boxes.py
+++ b/torchvision/tv_tensors/_bounding_boxes.py
@@ -116,6 +116,23 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_
         bounding_boxes.clamping_mode = clamping_mode
         return bounding_boxes
 
+    def __wrap__(
+        self,
+        tensor: torch.Tensor,
+        *,
+        format: BoundingBoxFormat | str | None = None,
+        canvas_size: tuple[int, int] | None = None,
+        clamping_mode: CLAMPING_MODE_TYPE = None,
+        check_dims: bool | None = None,
+    ) -> BoundingBoxes:
+        return BoundingBoxes._wrap(
+            tensor,
+            format=format if format is not None else self.format,
+            canvas_size=canvas_size if canvas_size is not None else self.canvas_size,
+            clamping_mode=clamping_mode if clamping_mode is not None else self.clamping_mode,
+            check_dims=False,
+        )
+
     def __new__(
         cls,
         data: Any,
@@ -153,17 +170,9 @@ def _wrap_output(
         )
 
         if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
-            output = BoundingBoxes._wrap(
-                output, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False
-            )
+            output = first_bbox_from_args.__wrap__(output)
         elif isinstance(output, (tuple, list)):
-            # This branch exists for chunk() and unbind()
-            output = type(output)(
-                BoundingBoxes._wrap(
-                    part, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False
-                )
-                for part in output
-            )
+            output = type(output)(first_bbox_from_args.__wrap__(part) for part in output)
         return output
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
diff --git a/torchvision/tv_tensors/_keypoints.py b/torchvision/tv_tensors/_keypoints.py
index aede31ad7db..46aa6473ef2 100644
--- a/torchvision/tv_tensors/_keypoints.py
+++ b/torchvision/tv_tensors/_keypoints.py
@@ -67,6 +67,13 @@ def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], check_dims
         points.canvas_size = canvas_size
         return points
 
+    def __wrap__(self, tensor: torch.Tensor, *, canvas_size: tuple[int, int] | None = None) -> KeyPoints:
+        return KeyPoints._wrap(
+            tensor,
+            canvas_size=canvas_size if canvas_size is not None else self.canvas_size,
+            check_dims=False,
+        )
+
     def __new__(
         cls,
         data: Any,
@@ -89,13 +96,11 @@ def _wrap_output(
         # Similar to BoundingBoxes._wrap_output(), see comment there.
         flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
         first_keypoints_from_args = next(x for x in flat_params if isinstance(x, KeyPoints))
-        canvas_size = first_keypoints_from_args.canvas_size
 
         if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints):
-            output = KeyPoints._wrap(output, canvas_size=canvas_size, check_dims=False)
+            output = first_keypoints_from_args.__wrap__(output)
         elif isinstance(output, (tuple, list)):
-            # This branch exists for chunk() and unbind()
-            output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, check_dims=False) for part in output)
+            output = type(output)(first_keypoints_from_args.__wrap__(part) for part in output)
         return output
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
diff --git a/torchvision/tv_tensors/_tv_tensor.py b/torchvision/tv_tensors/_tv_tensor.py
index 9f07fc8f226..e8805b64339 100644
--- a/torchvision/tv_tensors/_tv_tensor.py
+++ b/torchvision/tv_tensors/_tv_tensor.py
@@ -9,6 +9,7 @@
 from torch.types import _device, _dtype, _size
 
 from torchvision.tv_tensors._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass
+from typing_extensions import Self
 
 D = TypeVar("D", bound="TVTensor")
 
@@ -49,6 +50,9 @@
         output = type(output)(cls._wrap_output(part, args, kwargs) for part in output)
         return output
 
+    def __wrap__(self, tensor: torch.Tensor) -> Self:
+        return tensor.as_subclass(type(self))
+
     @classmethod
     def __torch_function__(
         cls,