From 177f66efb13eb73f72c1a1eea9f1d6278d79ba4e Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Sun, 15 Feb 2026 22:13:39 +0530 Subject: [PATCH] fix resize_mask ignoring NEAREST_EXACT interpolation mode --- .../transforms/v2/functional/_geometry.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index a35de4ac95d..0e40791deb4 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -401,14 +401,24 @@ def __resize_image_pil_dispatch( return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size) -def resize_mask(mask: torch.Tensor, size: Optional[list[int]], max_size: Optional[int] = None) -> torch.Tensor: +def resize_mask( + mask: torch.Tensor, + size: Optional[list[int]], + max_size: Optional[int] = None, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, +) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) needs_squeeze = True else: needs_squeeze = False - output = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) + if isinstance(interpolation, int): + interpolation = InterpolationMode(interpolation) + if interpolation not in {InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT}: + interpolation = InterpolationMode.NEAREST + + output = resize_image(mask, size=size, interpolation=interpolation, max_size=max_size) if needs_squeeze: output = output.squeeze(0) @@ -420,7 +430,8 @@ def resize_mask(mask: torch.Tensor, size: Optional[list[int]], max_size: Optiona def _resize_mask_dispatch( inpt: tv_tensors.Mask, size: list[int], max_size: Optional[int] = None, **kwargs: Any ) -> tv_tensors.Mask: - output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size) + interpolation = kwargs.pop("interpolation", InterpolationMode.NEAREST) + output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size, interpolation=interpolation) return tv_tensors.wrap(output, like=inpt)