diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 759f1f44643..fa9288c6b46 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -15,11 +15,9 @@
 import numpy as np
 import PIL.Image
 import pytest
-
 import torch
 import torchvision.ops
 import torchvision.transforms.v2 as transforms
-
 from common_utils import (
     assert_equal,
     cache,
@@ -40,14 +38,12 @@
     needs_cvcuda,
     set_rng_seed,
 )
-
 from torch import nn
 from torch.testing import assert_close
 from torch.utils._pytree import tree_flatten, tree_map
 from torch.utils.data import DataLoader, default_collate
 from torchvision import tv_tensors
 from torchvision.ops.boxes import box_iou
-
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
 from torchvision.transforms.v2 import functional as F
@@ -63,7 +59,6 @@
 )
 from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal
-
 
 # turns all warnings into errors for this module
 pytestmark = [pytest.mark.filterwarnings("error")]
 
@@ -8120,3 +8115,64 @@ def test_different_sizes(self, make_input1, make_input2, query):
     def test_no_valid_input(self, query):
         with pytest.raises(TypeError, match="No image"):
             query(["blah"])
+
+
+class TestThreadSafeGenerator:
+    """Test that transforms correctly use torch.thread_safe_generator().
+
+    For multiprocessing workers, thread_safe_generator() returns None,
+    so transforms use the default process global RNG,
+    i.e. for a multiprocessing worker the RNG of that process.
+    For thread workers, it returns a thread-local torch.Generator.
+    """
+
+    TRANSFORMS = [
+        transforms.RandomResizedCrop(size=(24, 24)),
+        transforms.RandomRotation(degrees=10),
+        transforms.RandomAffine(degrees=10),
+        transforms.RandomCrop(size=(24, 24), pad_if_needed=True),
+        transforms.RandomPerspective(p=1.0),
+        transforms.RandomErasing(p=1.0),
+        transforms.ScaleJitter(target_size=(24, 24)),
+    ]
+
+    class TransformDataset(torch.utils.data.Dataset):
+        def __init__(self, size, transform):
+            self.size = size
+            self.transform = transform
+            self.image = make_image((32, 32))
+
+        def __getitem__(self, idx):
+            return self.transform(self.image)
+
+        def __len__(self):
+            return self.size
+
+    @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
+    def test_multiprocessing_workers(self, transform):
+        """With multiprocessing DataLoader workers, thread_safe_generator()
+        returns None and transforms use the per-process global RNG.
+        Each worker gets a different seed, so results should differ."""
+        dataset = self.TransformDataset(size=2, transform=transform)
+        dl = DataLoader(dataset, batch_size=1, num_workers=2)
+        batch0, batch1 = list(dl)
+        assert not torch.equal(batch0, batch1)
+
+    @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
+    def test_thread_worker_uses_thread_local_generator(self, transform):
+        """In thread workers, thread_safe_generator() returns a thread-local
+        Generator. Mimic two workers with differently seeded generators
+        and verify they produce different results."""
+        image = make_image((32, 32))
+
+        g0 = torch.Generator()
+        g0.manual_seed(0)
+        with mock.patch("torch.thread_safe_generator", return_value=g0):
+            result_worker0 = transform(image)
+
+        g1 = torch.Generator()
+        g1.manual_seed(1)
+        with mock.patch("torch.thread_safe_generator", return_value=g1):
+            result_worker1 = transform(image)
+
+        assert not torch.equal(result_worker0, result_worker1)
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index c88f3d9a504..eedda4c4c6a 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -6,7 +6,6 @@
 
 import PIL.Image
 import torch
-
 from torchvision import transforms as _transforms, tv_tensors
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import _get_perspective_coeffs
@@ -281,13 +280,16 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         height, width = query_size(flat_inputs)
         area = height * width
 
+        g = torch.thread_safe_generator()
+
         log_ratio = self._log_ratio
         for _ in range(10):
-            target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
+            target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
             aspect_ratio = torch.exp(
                 torch.empty(1).uniform_(
                     log_ratio[0],  # type: ignore[arg-type]
                     log_ratio[1],  # type: ignore[arg-type]
+                    generator=g,
                 )
             ).item()
 
@@ -295,8 +297,8 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
             h = int(round(math.sqrt(target_area / aspect_ratio)))
 
             if 0 < w <= width and 0 < h <= height:
-                i = torch.randint(0, height - h + 1, size=(1,)).item()
-                j = torch.randint(0, width - w + 1, size=(1,)).item()
+                i = torch.randint(0, height - h + 1, size=(1,), generator=g).item()
+                j = torch.randint(0, width - w + 1, size=(1,), generator=g).item()
                 break
         else:
             # Fallback to central crop
@@ -547,11 +549,13 @@ def __init__(
     def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         orig_h, orig_w = query_size(flat_inputs)
 
-        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
+        g = torch.thread_safe_generator()
+
+        r = self.side_range[0] + torch.rand(1, generator=g) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)
         canvas_height = int(orig_h * r)
 
-        r = torch.rand(2)
+        r = torch.rand(2, generator=g)
         left = int((canvas_width - orig_w) * r[0])
         top = int((canvas_height - orig_h) * r[1])
         right = canvas_width - (left + orig_w)
@@ -628,7 +632,8 @@ def __init__(
         self.center = center
 
     def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
-        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
+        g = torch.thread_safe_generator()
+        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item()
         return dict(angle=angle)
 
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
@@ -728,26 +733,28 @@ def __init__(
     def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         height, width = query_size(flat_inputs)
 
-        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
+        g = torch.thread_safe_generator()
+
+        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item()
         if self.translate is not None:
             max_dx = float(self.translate[0] * width)
             max_dy = float(self.translate[1] * height)
-            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
-            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
+            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx, generator=g).item()))
+            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy, generator=g).item()))
             translate = (tx, ty)
         else:
             translate = (0, 0)
 
         if self.scale is not None:
-            scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
+            scale = torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
         else:
             scale = 1.0
 
         shear_x = shear_y = 0.0
         if self.shear is not None:
-            shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
+            shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1], generator=g).item()
             if len(self.shear) == 4:
-                shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
+                shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3], generator=g).item()
 
         shear = (shear_x, shear_y)
         return dict(angle=angle, translate=translate, scale=scale, shear=shear)
@@ -885,13 +892,15 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         padding = [pad_left, pad_top, pad_right, pad_bottom]
         needs_pad = any(padding)
 
+        g = torch.thread_safe_generator()
+
         needs_vert_crop, top = (
-            (True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
+            (True, int(torch.randint(0, padded_height - cropped_height + 1, size=(), generator=g)))
             if padded_height > cropped_height
             else (False, 0)
         )
         needs_horz_crop, left = (
-            (True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
+            (True, int(torch.randint(0, padded_width - cropped_width + 1, size=(), generator=g)))
             if padded_width > cropped_width
             else (False, 0)
         )
@@ -970,21 +979,24 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         half_width = width // 2
         bound_height = int(distortion_scale * half_height) + 1
         bound_width = int(distortion_scale * half_width) + 1
+
+        g = torch.thread_safe_generator()
+
         topleft = [
-            int(torch.randint(0, bound_width, size=(1,))),
-            int(torch.randint(0, bound_height, size=(1,))),
+            int(torch.randint(0, bound_width, size=(1,), generator=g)),
+            int(torch.randint(0, bound_height, size=(1,), generator=g)),
         ]
         topright = [
-            int(torch.randint(width - bound_width, width, size=(1,))),
-            int(torch.randint(0, bound_height, size=(1,))),
+            int(torch.randint(width - bound_width, width, size=(1,), generator=g)),
+            int(torch.randint(0, bound_height, size=(1,), generator=g)),
         ]
         botright = [
-            int(torch.randint(width - bound_width, width, size=(1,))),
-            int(torch.randint(height - bound_height, height, size=(1,))),
+            int(torch.randint(width - bound_width, width, size=(1,), generator=g)),
+            int(torch.randint(height - bound_height, height, size=(1,), generator=g)),
         ]
         botleft = [
-            int(torch.randint(0, bound_width, size=(1,))),
-            int(torch.randint(height - bound_height, height, size=(1,))),
+            int(torch.randint(0, bound_width, size=(1,), generator=g)),
+            int(torch.randint(height - bound_height, height, size=(1,), generator=g)),
         ]
         startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
         endpoints = [topleft, topright, botright, botleft]
@@ -1065,7 +1077,9 @@ def __init__(
     def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         height, width = query_size(flat_inputs)
 
-        dx = torch.rand(1, 1, height, width) * 2 - 1
+        g = torch.thread_safe_generator()
+
+        dx = torch.rand(1, 1, height, width, generator=g) * 2 - 1
         if self.sigma[0] > 0.0:
             kx = int(8 * self.sigma[0] + 1)
             # if kernel size is even we have to make it odd
@@ -1074,7 +1088,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
             dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
         dx = dx * self.alpha[0] / width
 
-        dy = torch.rand(1, 1, height, width) * 2 - 1
+        dy = torch.rand(1, 1, height, width, generator=g) * 2 - 1
         if self.sigma[1] > 0.0:
             ky = int(8 * self.sigma[1] + 1)
             # if kernel size is even we have to make it odd
@@ -1157,16 +1171,18 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         orig_h, orig_w = query_size(flat_inputs)
         bboxes = get_bounding_boxes(flat_inputs)
 
+        g = torch.thread_safe_generator()
+
         while True:
             # sample an option
-            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
+            idx = int(torch.randint(low=0, high=len(self.options), size=(1,), generator=g))
             min_jaccard_overlap = self.options[idx]
             if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
                 return dict()
 
             for _ in range(self.trials):
                 # check the aspect ratio limitations
-                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
+                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2, generator=g)
                 new_w = int(orig_w * r[0])
                 new_h = int(orig_h * r[1])
                 aspect_ratio = new_w / new_h
@@ -1174,7 +1190,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
                     continue
 
                 # check for 0 area crops
-                r = torch.rand(2)
+                r = torch.rand(2, generator=g)
                 left = int((orig_w - new_w) * r[0])
                 top = int((orig_h - new_h) * r[1])
                 right = left + new_w
@@ -1206,7 +1222,6 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
 
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
-
         if len(params) < 1:
             return inpt
 
@@ -1276,7 +1291,9 @@ def __init__(
     def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         orig_height, orig_width = query_size(flat_inputs)
 
-        scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
+        g = torch.thread_safe_generator()
+
+        scale = self.scale_range[0] + torch.rand(1, generator=g) * (self.scale_range[1] - self.scale_range[0])
         r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
         new_width = int(orig_width * r)
         new_height = int(orig_height * r)
@@ -1341,7 +1358,9 @@ def __init__(
    def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
         orig_height, orig_width = query_size(flat_inputs)
 
-        min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
+        g = torch.thread_safe_generator()
+
+        min_size = self.min_size[int(torch.randint(len(self.min_size), (), generator=g))]
         r = min_size / min(orig_height, orig_width)
         if self.max_size is not None:
             r = min(r, self.max_size / max(orig_height, orig_width))
@@ -1418,7 +1437,8 @@ def __init__(
         self.antialias = antialias
 
     def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
-        size = int(torch.randint(self.min_size, self.max_size, ()))
+        g = torch.thread_safe_generator()
+        size = int(torch.randint(self.min_size, self.max_size, (), generator=g))
         return dict(size=[size])
 
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index ac84fcb6c82..ae02e736f05 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -178,7 +178,8 @@ def forward(self, *inputs: Any) -> Any:
 
         self.check_inputs(flat_inputs)
 
-        if torch.rand(1) >= self.p:
+        g = torch.thread_safe_generator()
+        if torch.rand(1, generator=g) >= self.p:
             return inputs
 
         needs_transform_list = self._needs_transform_list(flat_inputs)