Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
import numpy as np
import PIL.Image
import pytest

import torch
import torchvision.ops
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_equal,
cache,
Expand All @@ -40,14 +38,12 @@
needs_cvcuda,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou

from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
from torchvision.transforms.v2 import functional as F
Expand All @@ -63,7 +59,6 @@
)
from torchvision.transforms.v2.functional._utils import _get_kernel, _import_cvcuda, _register_kernel_internal


# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]

Expand Down Expand Up @@ -8120,3 +8115,64 @@ def test_different_sizes(self, make_input1, make_input2, query):
def test_no_valid_input(self, query):
with pytest.raises(TypeError, match="No image"):
query(["blah"])


class TestThreadSafeGenerator:
    """Verify that random transforms draw from torch.thread_safe_generator().

    In multiprocessing DataLoader workers thread_safe_generator() yields
    None, so the transforms fall back to the per-process global RNG (each
    worker process is seeded differently).  In thread workers it yields a
    thread-local torch.Generator instead.
    """

    TRANSFORMS = [
        transforms.RandomResizedCrop(size=(24, 24)),
        transforms.RandomRotation(degrees=10),
        transforms.RandomAffine(degrees=10),
        transforms.RandomCrop(size=(24, 24), pad_if_needed=True),
        transforms.RandomPerspective(p=1.0),
        transforms.RandomErasing(p=1.0),
        transforms.ScaleJitter(target_size=(24, 24)),
    ]

    class TransformDataset(torch.utils.data.Dataset):
        # Minimal dataset: one fixed image, transformed anew on every access.
        def __init__(self, size, transform):
            self.size = size
            self.transform = transform
            self.image = make_image((32, 32))

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            return self.transform(self.image)

    @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
    def test_multiprocessing_workers(self, transform):
        """With multiprocessing workers, thread_safe_generator() returns None
        and each worker process uses its own differently-seeded global RNG,
        so the same sample transformed in two workers must differ."""
        loader = DataLoader(
            self.TransformDataset(size=2, transform=transform),
            batch_size=1,
            num_workers=2,
        )
        first, second = loader
        assert not torch.equal(first, second)

    @pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
    def test_thread_worker_uses_thread_local_generator(self, transform):
        """Mimic two thread workers by patching thread_safe_generator() with
        differently seeded generators; their outputs must differ."""
        image = make_image((32, 32))

        results = []
        for seed in (0, 1):
            gen = torch.Generator()
            gen.manual_seed(seed)
            with mock.patch("torch.thread_safe_generator", return_value=gen):
                results.append(transform(image))

        assert not torch.equal(results[0], results[1])
84 changes: 52 additions & 32 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import PIL.Image
import torch

from torchvision import transforms as _transforms, tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
Expand Down Expand Up @@ -281,22 +280,25 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
height, width = query_size(flat_inputs)
area = height * width

g = torch.thread_safe_generator()

log_ratio = self._log_ratio
for _ in range(10):
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
generator=g,
)
).item()

w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))

if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
i = torch.randint(0, height - h + 1, size=(1,), generator=g).item()
j = torch.randint(0, width - w + 1, size=(1,), generator=g).item()
break
else:
# Fallback to central crop
Expand Down Expand Up @@ -547,11 +549,13 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
orig_h, orig_w = query_size(flat_inputs)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
g = torch.thread_safe_generator()

r = self.side_range[0] + torch.rand(1, generator=g) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)

r = torch.rand(2)
r = torch.rand(2, generator=g)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
Expand Down Expand Up @@ -628,7 +632,8 @@ def __init__(
self.center = center

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
    """Sample the rotation angle for this call.

    Draws uniformly from ``self.degrees`` using the thread-safe generator
    so that thread-based DataLoader workers each consume an independent,
    thread-local RNG stream.  ``flat_inputs`` is unused here.
    """
    # Single generator-seeded draw; the stale pre-change draw (without a
    # generator) left over in the rendered diff has been removed.
    g = torch.thread_safe_generator()
    angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item()
    return dict(angle=angle)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
Expand Down Expand Up @@ -728,26 +733,28 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
height, width = query_size(flat_inputs)

angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
g = torch.thread_safe_generator()

angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item()
if self.translate is not None:
max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height)
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx, generator=g).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy, generator=g).item()))
translate = (tx, ty)
else:
translate = (0, 0)

if self.scale is not None:
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
else:
scale = 1.0

shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1], generator=g).item()
if len(self.shear) == 4:
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3], generator=g).item()

shear = (shear_x, shear_y)
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
Expand Down Expand Up @@ -885,13 +892,15 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
padding = [pad_left, pad_top, pad_right, pad_bottom]
needs_pad = any(padding)

g = torch.thread_safe_generator()

needs_vert_crop, top = (
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=(), generator=g)))
if padded_height > cropped_height
else (False, 0)
)
needs_horz_crop, left = (
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=(), generator=g)))
if padded_width > cropped_width
else (False, 0)
)
Expand Down Expand Up @@ -970,21 +979,24 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
half_width = width // 2
bound_height = int(distortion_scale * half_height) + 1
bound_width = int(distortion_scale * half_width) + 1

g = torch.thread_safe_generator()

topleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
int(torch.randint(0, bound_width, size=(1,), generator=g)),
int(torch.randint(0, bound_height, size=(1,), generator=g)),
]
topright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
int(torch.randint(width - bound_width, width, size=(1,), generator=g)),
int(torch.randint(0, bound_height, size=(1,), generator=g)),
]
botright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
int(torch.randint(width - bound_width, width, size=(1,), generator=g)),
int(torch.randint(height - bound_height, height, size=(1,), generator=g)),
]
botleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
int(torch.randint(0, bound_width, size=(1,), generator=g)),
int(torch.randint(height - bound_height, height, size=(1,), generator=g)),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
Expand Down Expand Up @@ -1065,7 +1077,9 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
height, width = query_size(flat_inputs)

dx = torch.rand(1, 1, height, width) * 2 - 1
g = torch.thread_safe_generator()

dx = torch.rand(1, 1, height, width, generator=g) * 2 - 1
if self.sigma[0] > 0.0:
kx = int(8 * self.sigma[0] + 1)
# if kernel size is even we have to make it odd
Expand All @@ -1074,7 +1088,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / width

dy = torch.rand(1, 1, height, width) * 2 - 1
dy = torch.rand(1, 1, height, width, generator=g) * 2 - 1
if self.sigma[1] > 0.0:
ky = int(8 * self.sigma[1] + 1)
# if kernel size is even we have to make it odd
Expand Down Expand Up @@ -1157,24 +1171,26 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
orig_h, orig_w = query_size(flat_inputs)
bboxes = get_bounding_boxes(flat_inputs)

g = torch.thread_safe_generator()

while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
idx = int(torch.randint(low=0, high=len(self.options), size=(1,), generator=g))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
return dict()

for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2, generator=g)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue

# check for 0 area crops
r = torch.rand(2)
r = torch.rand(2, generator=g)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
Expand Down Expand Up @@ -1206,7 +1222,6 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:

if len(params) < 1:
return inpt

Expand Down Expand Up @@ -1276,7 +1291,9 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
orig_height, orig_width = query_size(flat_inputs)

scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
g = torch.thread_safe_generator()

scale = self.scale_range[0] + torch.rand(1, generator=g) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
Expand Down Expand Up @@ -1341,7 +1358,9 @@ def __init__(
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
orig_height, orig_width = query_size(flat_inputs)

min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
g = torch.thread_safe_generator()

min_size = self.min_size[int(torch.randint(len(self.min_size), (), generator=g))]
r = min_size / min(orig_height, orig_width)
if self.max_size is not None:
r = min(r, self.max_size / max(orig_height, orig_width))
Expand Down Expand Up @@ -1418,7 +1437,8 @@ def __init__(
self.antialias = antialias

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
    """Sample one target size uniformly from [min_size, max_size).

    Uses the thread-safe generator so thread-based DataLoader workers get
    independent, thread-local RNG streams.  ``flat_inputs`` is unused here.
    """
    # Single generator-seeded draw; the stale pre-change draw (without a
    # generator) left over in the rendered diff has been removed.
    g = torch.thread_safe_generator()
    size = int(torch.randint(self.min_size, self.max_size, (), generator=g))
    return dict(size=[size])

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
Expand Down
3 changes: 2 additions & 1 deletion torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def forward(self, *inputs: Any) -> Any:

self.check_inputs(flat_inputs)

if torch.rand(1) >= self.p:
g = torch.thread_safe_generator()
if torch.rand(1, generator=g) >= self.p:
return inputs

needs_transform_list = self._needs_transform_list(flat_inputs)
Expand Down
Loading