Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 84 additions & 83 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,74 @@ def _affine_image_pil(
return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


# TODO: Consider merging/unifying this with the bbox implementation
def _create_affine_matrix(
canvas_size: tuple[int, int],
angle: Union[int, float],
translate: list[float],
scale: float,
shear: list[float],
center: Optional[list[float]],
dtype: torch.dtype,
device: torch.device,
) -> tuple[torch.Tensor, list[float]]:
"""Create transposed affine matrix for point transformation.

Returns:
transposed_affine_matrix: (3, 2) matrix for torch.matmul(points_homogeneous, matrix)
center: Resolved center coordinates [cx, cy]
"""
angle, translate, shear, center = _affine_parse_args(
angle, translate, scale, shear, InterpolationMode.NEAREST, center
)

if center is None:
height, width = canvas_size
center = [width * 0.5, height * 0.5]

affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
transposed_affine_matrix = torch.tensor(affine_vector, dtype=dtype, device=device).reshape(2, 3).T

return transposed_affine_matrix, center


def _apply_affine_expand(
transposed_affine_matrix: torch.Tensor,
canvas_size: tuple[int, int],
center: list[float],
angle: Union[int, float],
translate: list[float],
scale: float,
shear: list[float],
) -> tuple[torch.Tensor, tuple[int, int]]:
"""Compute translation and new canvas size for expand mode.

Returns:
translation: (1, 2) tensor to subtract from transformed points
new_canvas_size: (new_height, new_width)
"""
dtype = transposed_affine_matrix.dtype
device = transposed_affine_matrix.device
height, width = canvas_size

canvas_corners = torch.tensor(
[
[0.0, 0.0, 1.0],
[0.0, float(height), 1.0],
[float(width), float(height), 1.0],
[float(width), 0.0, 1.0],
],
dtype=dtype,
device=device,
)
new_corners = torch.matmul(canvas_corners, transposed_affine_matrix)
translation = torch.amin(new_corners, dim=0, keepdim=True)

affine_vector = _get_inverse_affine_matrix(center, float(angle), translate, scale, shear)
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)

return translation, (new_height, new_width)


def _affine_keypoints_with_expand(
keypoints: torch.Tensor,
canvas_size: tuple[int, int],
Expand All @@ -1005,53 +1072,20 @@ def _affine_keypoints_with_expand(
dtype = keypoints.dtype
device = keypoints.device

angle, translate, shear, center = _affine_parse_args(
angle, translate, scale, shear, InterpolationMode.NEAREST, center
)

if center is None:
height, width = canvas_size
center = [width * 0.5, height * 0.5]

affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
transposed_affine_matrix = (
torch.tensor(
affine_vector,
dtype=dtype,
device=device,
)
.reshape(2, 3)
.T
transposed_affine_matrix, center = _create_affine_matrix(
canvas_size, angle, translate, scale, shear, center, dtype, device
)

# 1) We transform points into a tensor of points with shape (N, 3), where N is the number of points.
# Transform points: add homogeneous coordinate and apply affine matrix
points = keypoints.reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
# 2) Now let's transform the points using affine matrix
transformed_points = torch.matmul(points, transposed_affine_matrix)

if expand:
# Compute minimum point for transformed image frame:
# Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
height, width = canvas_size
points = torch.tensor(
[
[0.0, 0.0, 1.0],
[0.0, float(height), 1.0],
[float(width), float(height), 1.0],
[float(width), 0.0, 1.0],
],
dtype=dtype,
device=device,
translation, canvas_size = _apply_affine_expand(
transposed_affine_matrix, canvas_size, center, angle, translate, scale, shear
)
new_points = torch.matmul(points, transposed_affine_matrix)
tr = torch.amin(new_points, dim=0, keepdim=True)
# Translate keypoints
transformed_points.sub_(tr)
# Estimate meta-data for image with inverted=True
affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
canvas_size = (new_height, new_width)
transformed_points.sub_(translation)

out_keypoints = transformed_points.reshape(original_shape)
out_keypoints = out_keypoints.to(original_dtype)
Expand Down Expand Up @@ -1129,37 +1163,21 @@ def _affine_bounding_boxes_with_expand(
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format, inplace=True)
).reshape(-1, intermediate_shape)

angle, translate, shear, center = _affine_parse_args(
angle, translate, scale, shear, InterpolationMode.NEAREST, center
transposed_affine_matrix, center = _create_affine_matrix(
canvas_size, angle, translate, scale, shear, center, bounding_boxes.dtype, device
)

if center is None:
height, width = canvas_size
center = [width * 0.5, height * 0.5]

affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
transposed_affine_matrix = (
torch.tensor(
affine_vector,
dtype=bounding_boxes.dtype,
device=device,
)
.reshape(2, 3)
.T
)
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
# Extract corner points from bounding boxes
if is_rotated:
points = bounding_boxes.reshape(-1, 2)
else:
points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=bounding_boxes.dtype)], dim=-1)
# 2) Now let's transform the points using affine matrix

# Transform points using affine matrix
transformed_points = torch.matmul(points, transposed_affine_matrix)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:

# Recompute bounding boxes from transformed corner points
if is_rotated:
transformed_points = transformed_points.reshape(-1, 8)
out_bboxes = _parallelogram_to_bounding_boxes(transformed_points)
Expand All @@ -1169,27 +1187,10 @@ def _affine_bounding_boxes_with_expand(
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

if expand:
# Compute minimum point for transformed image frame:
# Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
height, width = canvas_size
points = torch.tensor(
[
[0.0, 0.0, 1.0],
[0.0, float(height), 1.0],
[float(width), float(height), 1.0],
[float(width), 0.0, 1.0],
],
dtype=bounding_boxes.dtype,
device=device,
translation, canvas_size = _apply_affine_expand(
transposed_affine_matrix, canvas_size, center, angle, translate, scale, shear
)
new_points = torch.matmul(points, transposed_affine_matrix)
tr = torch.amin(new_points, dim=0, keepdim=True)
# Translate bounding boxes
out_bboxes.sub_(tr.repeat((1, 4 if is_rotated else 2)))
# Estimate meta-data for image with inverted=True
affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
canvas_size = (new_height, new_width)
out_bboxes.sub_(translation.repeat((1, 4 if is_rotated else 2)))

out_bboxes = clamp_bounding_boxes(
out_bboxes, format=intermediate_format, canvas_size=canvas_size, clamping_mode=clamping_mode
Expand Down