From 75c07087a51b613e191fb6647f63dd076362a082 Mon Sep 17 00:00:00 2001 From: gabrielfruet Date: Fri, 20 Feb 2026 15:22:00 -0300 Subject: [PATCH 1/2] feat: standard and extensible wrapping --- torchvision/tv_tensors/__init__.py | 32 +++++++++++------------ torchvision/tv_tensors/_bounding_boxes.py | 18 ++++++++----- torchvision/tv_tensors/_keypoints.py | 12 ++++++--- torchvision/tv_tensors/_tv_tensor.py | 25 ++++++++++++++---- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index fed23e73364..a79a8dcf7a5 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -1,3 +1,4 @@ +from typing import TypeVar import torch from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format @@ -7,33 +8,30 @@ from ._torch_function_helpers import set_return_type from ._tv_tensor import TVTensor from ._video import Video +from torchvision.tv_tensors._tv_tensor import TVTensor + + +TVTensorType = TypeVar("TVTensorType", bound=TVTensor) # TODO: Fix this. We skip this method as it leads to # RecursionError: maximum recursion depth exceeded while calling a Python object # Until `disable` is removed, there will be graph breaks after all calls to functional transforms @torch.compiler.disable -def wrap(wrappee, *, like, **kwargs): +def wrap(wrappee: torch.Tensor, *, like: TVTensorType, **kwargs) -> TVTensorType: """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``. - If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of - ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``. - Args: wrappee (Tensor): The tensor to convert. like (:class:`~torchvision.tv_tensors.TVTensor`): The reference. - ``wrappee`` will be converted into the same subclass as ``like``. - kwargs: Can contain "format", "canvas_size" and "clamping_mode" if ``like`` is a :class:`~torchvision.tv_tensor.BoundingBoxes`. - Ignored otherwise. + ``wrappee`` will be converted into the same subclass as ``like`` + maintaining the same metadata as ``like``. + kwargs: Optional overrides for metadata. For BoundingBoxes: ``format``, ``canvas_size``, ``clamping_mode``. + For KeyPoints: ``canvas_size``. """ - if isinstance(like, BoundingBoxes): - return type(like)._wrap( - wrappee, - format=kwargs.get("format", like.format), - canvas_size=kwargs.get("canvas_size", like.canvas_size), - clamping_mode=kwargs.get("clamping_mode", like.clamping_mode), + if not hasattr(like, "__wrap__"): + raise TypeError( + f"Expected `like` to have a `__wrap__` method, but got {type(like)}" ) - elif isinstance(like, KeyPoints): - return type(like)._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size)) - else: - return wrappee.as_subclass(type(like)) + + return like.__wrap__(wrappee, **kwargs) diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py index 7aa3e50458d..d05997a771c 100644 --- a/torchvision/tv_tensors/_bounding_boxes.py +++ b/torchvision/tv_tensors/_bounding_boxes.py @@ -116,6 +116,15 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_ bounding_boxes.clamping_mode = clamping_mode return bounding_boxes + def __wrap__(self, tensor: torch.Tensor, *, format=None, canvas_size=None, clamping_mode=None) -> BoundingBoxes: + return BoundingBoxes._wrap( + tensor, + format=format if format is not None else self.format, + canvas_size=canvas_size if canvas_size is not None else self.canvas_size, + clamping_mode=clamping_mode if clamping_mode is not None else self.clamping_mode, + check_dims=False, + ) + def __new__( cls, data: Any, @@ -153,15 +162,10 @@ def _wrap_output( ) if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes): - output = BoundingBoxes._wrap( - output, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False - ) + output = first_bbox_from_args.__wrap__(output) elif isinstance(output, (tuple, list)): - # This branch exists for chunk() and unbind() output = type(output)( - BoundingBoxes._wrap( - part, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False - ) + first_bbox_from_args.__wrap__(part) for part in output ) return output diff --git a/torchvision/tv_tensors/_keypoints.py b/torchvision/tv_tensors/_keypoints.py index aede31ad7db..c894ed9cdf0 100644 --- a/torchvision/tv_tensors/_keypoints.py +++ b/torchvision/tv_tensors/_keypoints.py @@ -67,6 +67,13 @@ def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], check_dims points.canvas_size = canvas_size return points + def __wrap__(self, tensor: torch.Tensor, *, canvas_size=None) -> KeyPoints: + return KeyPoints._wrap( + tensor, + canvas_size=canvas_size if canvas_size is not None else self.canvas_size, + check_dims=False, + ) + def __new__( cls, data: Any, @@ -92,10 +99,9 @@ def _wrap_output( canvas_size = first_keypoints_from_args.canvas_size if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints): - output = KeyPoints._wrap(output, canvas_size=canvas_size, check_dims=False) + output = first_keypoints_from_args.__wrap__(output) elif isinstance(output, (tuple, list)): - # This branch exists for chunk() and unbind() - output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, check_dims=False) for part in output) + output = type(output)(first_keypoints_from_args.__wrap__(part) for part in output) return output def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] diff --git a/torchvision/tv_tensors/_tv_tensor.py b/torchvision/tv_tensors/_tv_tensor.py index 9f07fc8f226..25e41bf5a61 100644 --- a/torchvision/tv_tensors/_tv_tensor.py +++ b/torchvision/tv_tensors/_tv_tensor.py @@ -3,12 +3,16 @@ from collections.abc import Mapping, Sequence from typing import Any, Callable, TypeVar +from typing_extensions import Self import torch from torch._C import DisableTorchFunctionSubclass from torch.types import _device, _dtype, _size -from torchvision.tv_tensors._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass +from torchvision.tv_tensors._torch_function_helpers import ( + _FORCE_TORCHFUNCTION_SUBCLASS, + _must_return_subclass, +) D = TypeVar("D", bound="TVTensor") @@ -30,8 +34,12 @@ def _to_tensor( requires_grad: bool | None = None, ) -> torch.Tensor: if requires_grad is None: - requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False - return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) + requires_grad = ( + data.requires_grad if isinstance(data, torch.Tensor) else False + ) + return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_( + requires_grad + ) @classmethod def _wrap_output( @@ -46,9 +54,14 @@ def _wrap_output( if isinstance(output, (tuple, list)): # Also handles things like namedtuples - output = type(output)(cls._wrap_output(part, args, kwargs) for part in output) + output = type(output)( + cls._wrap_output(part, args, kwargs) for part in output + ) return output + def __wrap__(self, tensor: torch.Tensor) -> Self: + return tensor.as_subclass(type(self)) + @classmethod def __torch_function__( cls, @@ -79,7 +92,9 @@ def __torch_function__( output = func(*args, **kwargs or dict()) must_return_subclass = _must_return_subclass() - if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)): + if must_return_subclass or ( + func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls) + ): # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails # in test_to_tv_tensor_reference(). # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in From 5b0311df22cf7345db3029c8082df7b366ef60b3 Mon Sep 17 00:00:00 2001 From: gabrielfruet Date: Mon, 23 Feb 2026 11:15:41 -0300 Subject: [PATCH 2/2] fix: format and typing --- torchvision/tv_tensors/__init__.py | 8 ++++---- torchvision/tv_tensors/_bounding_boxes.py | 15 ++++++++++----- torchvision/tv_tensors/_keypoints.py | 3 +-- torchvision/tv_tensors/_tv_tensor.py | 23 ++++++----------------- 4 files changed, 21 insertions(+), 28 deletions(-) diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index a79a8dcf7a5..745340822d1 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -1,4 +1,5 @@ from typing import TypeVar + import torch from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format @@ -7,9 +8,10 @@ from ._mask import Mask from ._torch_function_helpers import set_return_type from ._tv_tensor import TVTensor -from ._video import Video from torchvision.tv_tensors._tv_tensor import TVTensor +from ._video import Video + TVTensorType = TypeVar("TVTensorType", bound=TVTensor) @@ -30,8 +32,6 @@ def wrap(wrappee: torch.Tensor, *, like: TVTensorType, **kwargs) -> TVTensorType For KeyPoints: ``canvas_size``. """ if not hasattr(like, "__wrap__"): - raise TypeError( - f"Expected `like` to have a `__wrap__` method, but got {type(like)}" - ) + raise TypeError(f"Expected `like` to have a `__wrap__` method, but got {type(like)}") return like.__wrap__(wrappee, **kwargs) diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py index d05997a771c..2f67632f52d 100644 --- a/torchvision/tv_tensors/_bounding_boxes.py +++ b/torchvision/tv_tensors/_bounding_boxes.py @@ -116,7 +116,15 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_ bounding_boxes.clamping_mode = clamping_mode return bounding_boxes - def __wrap__(self, tensor: torch.Tensor, *, format=None, canvas_size=None, clamping_mode=None) -> BoundingBoxes: + def __wrap__( + self, + tensor: torch.Tensor, + *, + format: BoundingBoxFormat | str | None = None, + canvas_size: tuple[int, int] | None = None, + clamping_mode: CLAMPING_MODE_TYPE = None, + check_dims: bool | None = None, + ) -> BoundingBoxes: return BoundingBoxes._wrap( tensor, format=format if format is not None else self.format, @@ -164,10 +172,7 @@ def _wrap_output( if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes): output = first_bbox_from_args.__wrap__(output) elif isinstance(output, (tuple, list)): - output = type(output)( - first_bbox_from_args.__wrap__(part) - for part in output - ) + output = type(output)(first_bbox_from_args.__wrap__(part) for part in output) return output def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] diff --git a/torchvision/tv_tensors/_keypoints.py b/torchvision/tv_tensors/_keypoints.py index c894ed9cdf0..46aa6473ef2 100644 --- a/torchvision/tv_tensors/_keypoints.py +++ b/torchvision/tv_tensors/_keypoints.py @@ -67,7 +67,7 @@ def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], check_dims points.canvas_size = canvas_size return points - def __wrap__(self, tensor: torch.Tensor, *, canvas_size=None) -> KeyPoints: + def __wrap__(self, tensor: torch.Tensor, *, canvas_size: tuple[int, int] | None = None) -> KeyPoints: return KeyPoints._wrap( tensor, canvas_size=canvas_size if canvas_size is not None else self.canvas_size, @@ -96,7 +96,6 @@ def _wrap_output( # Similar to BoundingBoxes._wrap_output(), see comment there. flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator] first_keypoints_from_args = next(x for x in flat_params if isinstance(x, KeyPoints)) - canvas_size = first_keypoints_from_args.canvas_size if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints): output = first_keypoints_from_args.__wrap__(output) diff --git a/torchvision/tv_tensors/_tv_tensor.py b/torchvision/tv_tensors/_tv_tensor.py index 25e41bf5a61..e8805b64339 100644 --- a/torchvision/tv_tensors/_tv_tensor.py +++ b/torchvision/tv_tensors/_tv_tensor.py @@ -3,16 +3,13 @@ from collections.abc import Mapping, Sequence from typing import Any, Callable, TypeVar -from typing_extensions import Self import torch from torch._C import DisableTorchFunctionSubclass from torch.types import _device, _dtype, _size -from torchvision.tv_tensors._torch_function_helpers import ( - _FORCE_TORCHFUNCTION_SUBCLASS, - _must_return_subclass, -) +from torchvision.tv_tensors._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass +from typing_extensions import Self D = TypeVar("D", bound="TVTensor") @@ -34,12 +31,8 @@ def _to_tensor( requires_grad: bool | None = None, ) -> torch.Tensor: if requires_grad is None: - requires_grad = ( - data.requires_grad if isinstance(data, torch.Tensor) else False - ) - return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_( - requires_grad - ) + requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False + return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) @classmethod def _wrap_output( @@ -54,9 +47,7 @@ def _wrap_output( if isinstance(output, (tuple, list)): # Also handles things like namedtuples - output = type(output)( - cls._wrap_output(part, args, kwargs) for part in output - ) + output = type(output)(cls._wrap_output(part, args, kwargs) for part in output) return output def __wrap__(self, tensor: torch.Tensor) -> Self: @@ -92,9 +83,7 @@ def __torch_function__( output = func(*args, **kwargs or dict()) must_return_subclass = _must_return_subclass() - if must_return_subclass or ( - func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls) - ): + if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)): # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails # in test_to_tv_tensor_reference(). # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in