From 273b95cfc3ca8a7bf4814be9b541a4f39da0e745 Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Thu, 30 Oct 2025 11:36:02 -0700 Subject: [PATCH 1/4] update randomcrop --- torchvision/transforms/transforms.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index e33b3e28194..9f2e3155732 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -631,12 +631,13 @@ class RandomCrop(torch.nn.Module): """ @staticmethod - def get_params(img: Tensor, output_size: tuple[int, int]) -> tuple[int, int, int, int]: + def get_params(img: Tensor, output_size: tuple[int, int], generator: Optional[torch.Generator] = None) -> tuple[int, int, int, int]: """Get parameters for ``crop`` for a random crop. Args: img (PIL Image or Tensor): Image to be cropped. output_size (tuple): Expected output size of the crop. + generator (torch.Generator, optional): Random number generator. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. @@ -650,11 +651,11 @@ def get_params(img: Tensor, output_size: tuple[int, int]) -> tuple[int, int, int if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1,)).item() - j = torch.randint(0, w - tw + 1, size=(1,)).item() + i = torch.randint(0, h - th + 1, size=(1,), generator=generator).item() + j = torch.randint(0, w - tw + 1, size=(1,), generator=generator).item() return i, j, th, tw - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant", generator=None): super().__init__() _log_api_usage_once(self) @@ -664,6 +665,7 @@ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode self.pad_if_needed = pad_if_needed self.fill = fill self.padding_mode = padding_mode + self.generator = generator def forward(self, img): """ @@ -686,7 +688,7 @@ def forward(self, img): padding = [0, self.size[0] - height] img = F.pad(img, padding, self.fill, self.padding_mode) - i, j, h, w = self.get_params(img, self.size) + i, j, h, w = self.get_params(img, self.size, self.generator) return F.crop(img, i, j, h, w) From 52d3353428ef47f0c1b39696ec54d4384c2f4933 Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Thu, 15 Jan 2026 16:25:25 -0800 Subject: [PATCH 2/4] Use torch.thread_safe_generator --- torchvision/transforms/transforms.py | 12 ++-- torchvision/transforms/v2/_geometry.py | 83 ++++++++++++++++--------- torchvision/transforms/v2/_transform.py | 3 +- 3 files changed, 59 insertions(+), 39 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 9f2e3155732..e33b3e28194 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -631,13 +631,12 @@ class RandomCrop(torch.nn.Module): """ @staticmethod - def get_params(img: Tensor, output_size: tuple[int, int], generator: Optional[torch.Generator] = None) -> tuple[int, int, int, int]: + def get_params(img: Tensor, output_size: tuple[int, int]) -> tuple[int, int, int, int]: """Get parameters for ``crop`` for a random crop. Args: img (PIL Image or Tensor): Image to be cropped. output_size (tuple): Expected output size of the crop. - generator (torch.Generator, optional): Random number generator. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. @@ -651,11 +650,11 @@ def get_params(img: Tensor, output_size: tuple[int, int], generator: Optional[to if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1,), generator=generator).item() - j = torch.randint(0, w - tw + 1, size=(1,), generator=generator).item() + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() return i, j, th, tw - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant", generator=None): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): super().__init__() _log_api_usage_once(self) @@ -665,7 +664,6 @@ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode self.pad_if_needed = pad_if_needed self.fill = fill self.padding_mode = padding_mode - self.generator = generator def forward(self, img): """ @@ -688,7 +686,7 @@ def forward(self, img): padding = [0, self.size[0] - height] img = F.pad(img, padding, self.fill, self.padding_mode) - i, j, h, w = self.get_params(img, self.size, self.generator) + i, j, h, w = self.get_params(img, self.size) return F.crop(img, i, j, h, w) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index c88f3d9a504..e0b8074cbe3 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -281,13 +281,16 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) area = height * width + g = torch.thread_safe_generator() + log_ratio = self._log_ratio for _ in range(10): - target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] log_ratio[1], # type: ignore[arg-type] + generator=g, ) ).item() @@ -295,8 +298,8 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: h = int(round(math.sqrt(target_area / aspect_ratio))) if 0 < w <= width and 0 < h <= height: - i = torch.randint(0, height - h + 1, size=(1,)).item() - j = torch.randint(0, width - w + 1, size=(1,)).item() + i = torch.randint(0, height - h + 1, size=(1,), generator=g).item() + j = torch.randint(0, width - w + 1, size=(1,), generator=g).item() break else: # Fallback to central crop @@ -547,11 +550,13 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) - r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + g = torch.thread_safe_generator() + + r = self.side_range[0] + torch.rand(1, generator=g) * (self.side_range[1] - self.side_range[0]) canvas_width = int(orig_w * r) canvas_height = int(orig_h * r) - r = torch.rand(2) + r = torch.rand(2, generator=g) left = int((canvas_width - orig_w) * r[0]) top = int((canvas_height - orig_h) * r[1]) right = canvas_width - (left + orig_w) @@ -628,7 +633,8 @@ def __init__( self.center = center def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + g = torch.thread_safe_generator() + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item() return dict(angle=angle) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: @@ -728,26 +734,28 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + g = torch.thread_safe_generator() + + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item() if self.translate is not None: max_dx = float(self.translate[0] * width) max_dy = float(self.translate[1] * height) - tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) - ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx, generator=g).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy, generator=g).item())) translate = (tx, ty) else: translate = (0, 0) if self.scale is not None: - scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + scale = torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item() else: scale = 1.0 shear_x = shear_y = 0.0 if self.shear is not None: - shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item() + shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1], generator=g).item() if len(self.shear) == 4: - shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item() + shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3], generator=g).item() shear = (shear_x, shear_y) return dict(angle=angle, translate=translate, scale=scale, shear=shear) @@ -885,13 +893,15 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: padding = [pad_left, pad_top, pad_right, pad_bottom] needs_pad = any(padding) + g = torch.thread_safe_generator() + needs_vert_crop, top = ( - (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=(), generator=g))) if padded_height > cropped_height else (False, 0) ) needs_horz_crop, left = ( - (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=(), generator=g))) if padded_width > cropped_width else (False, 0) ) @@ -970,21 +980,24 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: half_width = width // 2 bound_height = int(distortion_scale * half_height) + 1 bound_width = int(distortion_scale * half_width) + 1 + + g = torch.thread_safe_generator() + topleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), + int(torch.randint(0, bound_width, size=(1,), generator=g)), + int(torch.randint(0, bound_height, size=(1,), generator=g)), ] topright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), + int(torch.randint(width - bound_width, width, size=(1,), generator=g)), + int(torch.randint(0, bound_height, size=(1,), generator=g)), ] botright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), + int(torch.randint(width - bound_width, width, size=(1,), generator=g)), + int(torch.randint(height - bound_height, height, size=(1,), generator=g)), ] botleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), + int(torch.randint(0, bound_width, size=(1,), generator=g)), + int(torch.randint(height - bound_height, height, size=(1,), generator=g)), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] @@ -1065,7 +1078,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) - dx = torch.rand(1, 1, height, width) * 2 - 1 + g = torch.thread_safe_generator() + + dx = torch.rand(1, 1, height, width, generator=g) * 2 - 1 if self.sigma[0] > 0.0: kx = int(8 * self.sigma[0] + 1) # if kernel size is even we have to make it odd @@ -1074,7 +1089,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) dx = dx * self.alpha[0] / width - dy = torch.rand(1, 1, height, width) * 2 - 1 + dy = torch.rand(1, 1, height, width, generator=g) * 2 - 1 if self.sigma[1] > 0.0: ky = int(8 * self.sigma[1] + 1) # if kernel size is even we have to make it odd @@ -1157,16 +1172,18 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) bboxes = get_bounding_boxes(flat_inputs) + g = torch.thread_safe_generator() + while True: # sample an option - idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + idx = int(torch.randint(low=0, high=len(self.options), size=(1,), generator=g)) min_jaccard_overlap = self.options[idx] if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option return dict() for _ in range(self.trials): # check the aspect ratio limitations - r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2, generator=g) new_w = int(orig_w * r[0]) new_h = int(orig_h * r[1]) aspect_ratio = new_w / new_h @@ -1174,7 +1191,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: continue # check for 0 area crops - r = torch.rand(2) + r = torch.rand(2, generator=g) left = int((orig_w - new_w) * r[0]) top = int((orig_h - new_h) * r[1]) right = left + new_w @@ -1206,7 +1223,6 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: - if len(params) < 1: return inpt @@ -1276,7 +1292,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) - scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + g = torch.thread_safe_generator() + + scale = self.scale_range[0] + torch.rand(1, generator=g) * (self.scale_range[1] - self.scale_range[0]) r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale new_width = int(orig_width * r) new_height = int(orig_height * r) @@ -1341,7 +1359,9 @@ def __init__( def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) - min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] + g = torch.thread_safe_generator() + + min_size = self.min_size[int(torch.randint(len(self.min_size), (), generator=g))] r = min_size / min(orig_height, orig_width) if self.max_size is not None: r = min(r, self.max_size / max(orig_height, orig_width)) @@ -1418,7 +1438,8 @@ def __init__( self.antialias = antialias def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - size = int(torch.randint(self.min_size, self.max_size, ())) + g = torch.thread_safe_generator() + size = int(torch.randint(self.min_size, self.max_size, (), generator=g)) return dict(size=[size]) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..ae02e736f05 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -178,7 +178,8 @@ def forward(self, *inputs: Any) -> Any: self.check_inputs(flat_inputs) - if torch.rand(1) >= self.p: + g = torch.thread_safe_generator() + if torch.rand(1, generator=g) >= self.p: return inputs needs_transform_list = self._needs_transform_list(flat_inputs) From 30b165a11bb49a3fff9ddc686cba3250075977fc Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Wed, 25 Feb 2026 14:19:48 -0800 Subject: [PATCH 3/4] add unit test mocking torch.thread_safe_generator --- test/test_transforms_v2.py | 58 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 759f1f44643..009da11d5b0 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -8120,3 +8120,61 @@ def test_different_sizes(self, make_input1, make_input2, query): def test_no_valid_input(self, query): with pytest.raises(TypeError, match="No image"): query(["blah"]) + + +class TestThreadSafeGenerator: + """Test that transforms correctly use torch.thread_safe_generator(). + + For multiprocessing workers, thread_safe_generator() returns None, + so transforms use the default process global RNG, + i.e. for a multiprocessing worker the RNG of that process. + For thread workers, it returns a thread-local torch.Generator. + """ + + TRANSFORMS = [ + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomVerticalFlip(p=0.5), + transforms.RandomResizedCrop(size=(24, 24)), + transforms.RandomRotation(degrees=10), + transforms.RandomAffine(degrees=10), + transforms.RandomCrop(size=(24, 24), pad_if_needed=True), + transforms.RandomPerspective(p=1.0), + transforms.RandomErasing(p=1.0), + transforms.ScaleJitter(target_size=(24, 24)), + ] + + @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) + def test_multiprocessing_worker_uses_global_rng(self, transform): + """In multiprocessing workers, thread_safe_generator() returns None, + so transforms use the default global (per-process) RNG. Mimic two + workers with different seeds and verify they produce different results.""" + image = make_image((32, 32)) + + with mock.patch("torch.thread_safe_generator", return_value=None): + torch.manual_seed(0) + result_worker0 = transform(image) + + with mock.patch("torch.thread_safe_generator", return_value=None): + torch.manual_seed(1) + result_worker1 = transform(image) + + assert not torch.equal(result_worker0, result_worker1) + + @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) + def test_thread_worker_uses_thread_local_generator(self, transform): + """In thread workers, thread_safe_generator() returns a thread-local + Generator. Mimic two workers with differently seeded generators + and verify they produce different results.""" + image = make_image((32, 32)) + + g0 = torch.Generator() + g0.manual_seed(0) + with mock.patch("torch.thread_safe_generator", return_value=g0): + result_worker0 = transform(image) + + g1 = torch.Generator() + g1.manual_seed(1) + with mock.patch("torch.thread_safe_generator", return_value=g1): + result_worker1 = transform(image) + + assert not torch.equal(result_worker0, result_worker1) From 6277f11baa749bcd4733a76d2b790a122d81b74e Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Wed, 25 Feb 2026 14:46:41 -0800 Subject: [PATCH 4/4] run ufmt, update mp unit test --- test/test_transforms_v2.py | 38 ++++++++++++-------------- torchvision/transforms/v2/_geometry.py | 1 - 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 009da11d5b0..fa9288c6b46 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -15,11 +15,9 @@ import numpy as np import PIL.Image import pytest - import torch import torchvision.ops import torchvision.transforms.v2 as transforms - from common_utils import ( assert_equal, cache, @@ -40,14 +38,12 @@ needs_cvcuda, set_rng_seed, ) - from torch import nn from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors from torchvision.ops.boxes import box_iou - from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping, to_pil_image from torchvision.transforms.v2 import functional as F @@ -63,7 +59,6 @@ ) from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal - # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -8132,8 +8127,6 @@ class TestThreadSafeGenerator: """ TRANSFORMS = [ - transforms.RandomHorizontalFlip(p=0.5), - transforms.RandomVerticalFlip(p=0.5), transforms.RandomResizedCrop(size=(24, 24)), transforms.RandomRotation(degrees=10), transforms.RandomAffine(degrees=10), @@ -8143,22 +8136,27 @@ class TestThreadSafeGenerator: transforms.ScaleJitter(target_size=(24, 24)), ] - @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) - def test_multiprocessing_worker_uses_global_rng(self, transform): - """In multiprocessing workers, thread_safe_generator() returns None, - so transforms use the default global (per-process) RNG. Mimic two - workers with different seeds and verify they produce different results.""" - image = make_image((32, 32)) + class TransformDataset(torch.utils.data.Dataset): + def __init__(self, size, transform): + self.size = size + self.transform = transform + self.image = make_image((32, 32)) - with mock.patch("torch.thread_safe_generator", return_value=None): - torch.manual_seed(0) - result_worker0 = transform(image) + def __getitem__(self, idx): + return self.transform(self.image) - with mock.patch("torch.thread_safe_generator", return_value=None): - torch.manual_seed(1) - result_worker1 = transform(image) + def __len__(self): + return self.size - assert not torch.equal(result_worker0, result_worker1) + @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) + def test_multiprocessing_workers(self, transform): + """With multiprocessing DataLoader workers, thread_safe_generator() + returns None and transforms use the per-process global RNG. + Each worker gets a different seed, so results should differ.""" + dataset = self.TransformDataset(size=2, transform=transform) + dl = DataLoader(dataset, batch_size=1, num_workers=2) + batch0, batch1 = list(dl) + assert not torch.equal(batch0, batch1) @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__) def test_thread_worker_uses_thread_local_generator(self, transform): diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index e0b8074cbe3..eedda4c4c6a 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -6,7 +6,6 @@ import PIL.Image import torch - from torchvision import transforms as _transforms, tv_tensors from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import _get_perspective_coeffs