diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 0e27218bc89..17ccbf02fe0 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2113,12 +2113,13 @@ def perspective_bounding_boxes( original_dtype = bounding_boxes.dtype is_rotated = tv_tensors.is_rotated_bounding_format(format) intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY - # TODO: first cast to float if bbox is int64 before convert_bounding_box_format + need_cast = not bounding_boxes.is_floating_point() + bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone() bounding_boxes = ( - convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format) + convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format, inplace=True) ).reshape(-1, 8 if is_rotated else 4) - dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 + dtype = bounding_boxes.dtype device = bounding_boxes.device # perspective_coeffs are computed as endpoint -> start point @@ -2430,18 +2431,21 @@ def elastic_bounding_boxes( # TODO: add in docstring about approximation we are doing for grid inversion device = bounding_boxes.device - dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 + original_dtype = bounding_boxes.dtype is_rotated = tv_tensors.is_rotated_bounding_format(format) + original_shape = bounding_boxes.shape + need_cast = not bounding_boxes.is_floating_point() + bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone() + dtype = bounding_boxes.dtype + if displacement.dtype != dtype or displacement.device != device: displacement = displacement.to(dtype=dtype, device=device) - original_shape = bounding_boxes.shape - # TODO: first cast to float if bbox is int64 before convert_bounding_box_format intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY bounding_boxes = ( - convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format) + convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format, inplace=True) ).reshape(-1, 5 if is_rotated else 4) id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) @@ -2473,10 +2477,12 @@ def elastic_bounding_boxes( out_bboxes, format=intermediate_format, canvas_size=canvas_size, clamping_mode=clamping_mode ) - return convert_bounding_box_format( - out_bboxes, old_format=intermediate_format, new_format=format, inplace=False + out_bboxes = convert_bounding_box_format( + out_bboxes, old_format=intermediate_format, new_format=format, inplace=True ).reshape(original_shape) + return out_bboxes.to(original_dtype) + @_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) def _elastic_bounding_boxes_dispatch(