Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2442,10 +2442,22 @@ def elastic_bounding_boxes(
convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format)
).reshape(-1, 5 if is_rotated else 4)

id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
# Create a grid with (H+1, W+1) to handle boundary coordinates (e.g., x=W, y=H)
extended_size = (canvas_size[0] + 1, canvas_size[1] + 1)
id_grid = _create_identity_grid(extended_size, device=device, dtype=dtype)

# Pad displacement to match extended grid size (replicate edge values)
padded_displacement = torch.nn.functional.pad(
displacement.permute(0, 3, 1, 2), # NHWC -> NCHW format
(0, 1, 0, 1), # pad right and bottom by 1
mode="replicate",
).permute(
0, 2, 3, 1
) # back to NHWC

# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid.sub_(displacement)
inv_grid = id_grid.sub_(padded_displacement)

# Get points from bboxes
points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
Expand Down
Loading