diff --git a/test/test_models_detection_export.py b/test/test_models_detection_export.py new file mode 100644 index 00000000000..3a12c293ca2 --- /dev/null +++ b/test/test_models_detection_export.py @@ -0,0 +1,283 @@ +import os + +import pytest +import torch +from common_utils import set_rng_seed +from torch.export import Dim, export +from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn + + +def _get_image(input_shape, device="cpu"): + GRACE_HOPPER = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" + ) + if os.path.exists(GRACE_HOPPER): + from PIL import Image + from torchvision import transforms + + img = Image.open(GRACE_HOPPER) + w, h = img.size + img = img.crop((0, 0, w, w)) + img = img.resize(input_shape[1:3]) + return transforms.ToTensor()(img).to(device=device) + return torch.rand(input_shape, device=device) + + +def _fpn_dynamic_shapes(): + """Dynamic shapes constrained to multiples of 64 for FPN compatibility. + + Strided convolutions in the backbone specialize on the parity of + ceil_to_32(dim)/32. Using multiples of 64 ensures consistent even + block counts, avoiding shape guards that would reject valid inputs. + """ + _h = Dim("_h", min=4, max=21) + _w = Dim("_w", min=4, max=21) + return {"images": [{1: 64 * _h, 2: 64 * _w}]} + + +@pytest.fixture(scope="module") +def fasterrcnn_model(): + """Load and pre-initialize FasterRCNN once for all tests in the module. + + _skip_resize=True bypasses the aspect-ratio-preserving resize in + GeneralizedRCNNTransform, which creates shape guards that specialize + to the tracing input's orientation (h<=w vs h>w). The caller is + responsible for pre-sizing inputs to the expected range. + """ + set_rng_seed(0) + model = fasterrcnn_mobilenet_v3_large_fpn( + num_classes=50, + weights_backbone=None, + box_score_thresh=0.02076, + _skip_resize=True, + ) + model.eval() + with torch.no_grad(): + _ = model([torch.randn(3, 256, 256)]) + return model + + +@pytest.fixture(scope="module") +def real_image(): + """Load the same real image used by test_detection_model.""" + return _get_image((3, 320, 320)) + + +class TestDetectionExport: + """Tests for torch.export of detection models. + + Verifies that export produces correct results matching eager mode, + works with dynamic shapes, handles edge cases, and supports both + strict=True and strict=False modes. + """ + + @pytest.mark.parametrize("strict", [False, True]) + def test_export_succeeds(self, fasterrcnn_model, strict): + """Export should succeed with dynamic H/W shapes.""" + with torch.no_grad(): + ep = export( + fasterrcnn_model, + ([torch.randn(3, 256, 320)],), + dynamic_shapes=_fpn_dynamic_shapes(), + strict=strict, + ) + assert ep is not None + + @pytest.mark.parametrize("strict", [False, True]) + def test_export_matches_eager_real_image(self, fasterrcnn_model, real_image, strict): + """Exported model output should match eager on the same real image.""" + with torch.no_grad(): + ep = export( + fasterrcnn_model, + ([torch.randn(3, 256, 320)],), + dynamic_shapes=_fpn_dynamic_shapes(), + strict=strict, + ) + + inp = [real_image.clone()] + with torch.no_grad(): + eager_out = fasterrcnn_model(inp) + export_out = ep.module()([real_image.clone()]) + + assert len(eager_out) == 1 and len(export_out) == 1 + for key in ("boxes", "scores", "labels"): + assert key in export_out[0], f"Missing key '{key}' in export output" + + # With random backbone weights, scores are near-zero and NMS ordering + # is sensitive to floating-point differences between eager and export. + # Only compare detections with confident scores; otherwise just verify + # structural correctness. + eager_confident = eager_out[0]["scores"] > 0.1 + export_confident = export_out[0]["scores"] > 0.1 + if eager_confident.sum() > 0 and eager_confident.sum() == export_confident.sum(): + torch.testing.assert_close( + eager_out[0]["boxes"][eager_confident], + export_out[0]["boxes"][export_confident], + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + eager_out[0]["scores"][eager_confident], + export_out[0]["scores"][export_confident], + atol=1e-6, + rtol=1e-6, + ) + + @pytest.mark.parametrize("strict", [False, True]) + def test_export_matches_eager_random_input(self, fasterrcnn_model, strict): + """Exported model should match eager on the same random input used by test_detection_model.""" + set_rng_seed(0) + with torch.no_grad(): + ep = export( + fasterrcnn_model, + ([torch.randn(3, 256, 320)],), + dynamic_shapes=_fpn_dynamic_shapes(), + strict=strict, + ) + + x = torch.rand(3, 320, 320) + with torch.no_grad(): + eager_out = fasterrcnn_model([x.clone()]) + export_out = ep.module()([x.clone()]) + + for key in ("boxes", "scores", "labels"): + assert key in export_out[0], f"Missing key '{key}' in export output" + + # Only compare confident detections (see test_export_matches_eager_real_image) + eager_confident = eager_out[0]["scores"] > 0.1 + export_confident = export_out[0]["scores"] > 0.1 + if eager_confident.sum() > 0 and eager_confident.sum() == export_confident.sum(): + torch.testing.assert_close( + eager_out[0]["boxes"][eager_confident], + export_out[0]["boxes"][export_confident], + atol=1e-4, + rtol=1e-4, + ) + + @pytest.mark.parametrize("strict", [False, True]) + @pytest.mark.parametrize("h_val,w_val", [(256, 512), (384, 320), (448, 640), (256, 256)]) + def test_export_dynamic_shapes(self, fasterrcnn_model, h_val, w_val, strict): + """Exported model should run on various input sizes without error.""" + with torch.no_grad(): + ep = export( + fasterrcnn_model, + ([torch.randn(3, 256, 320)],), + dynamic_shapes=_fpn_dynamic_shapes(), + strict=strict, + ) + + set_rng_seed(42) + x = torch.rand(3, h_val, w_val) + with torch.no_grad(): + eager_out = fasterrcnn_model([x.clone()]) + export_out = ep.module()([x.clone()]) + + assert len(eager_out) == 1 and len(export_out) == 1 + for key in ("boxes", "scores", "labels"): + assert key in export_out[0], f"Missing key '{key}' in export output" + + @pytest.mark.parametrize("strict", [False, True]) + def test_export_zero_detections(self, fasterrcnn_model, strict): + """Exported model should handle the case where NMS produces 0 detections.""" + # Use default thresholds — random noise should produce 0 detections + model = fasterrcnn_mobilenet_v3_large_fpn(num_classes=50, weights_backbone=None, _skip_resize=True) + model.eval() + with torch.no_grad(): + _ = model([torch.randn(3, 256, 256)]) + + with torch.no_grad(): + ep = export( + model, + ([torch.randn(3, 256, 320)],), + dynamic_shapes=_fpn_dynamic_shapes(), + strict=strict, + ) + + set_rng_seed(0) + x = torch.rand(3, 320, 512) + with torch.no_grad(): + eager_out = model([x.clone()]) + export_out = ep.module()([x.clone()]) + + assert len(eager_out[0]["boxes"]) == len(export_out[0]["boxes"]) + + @pytest.mark.parametrize("strict", [False, True]) + def test_export_many_detections(self, strict): + """Exported model with lowered thresholds should produce many detections.""" + model = fasterrcnn_mobilenet_v3_large_fpn(num_classes=50, weights_backbone=None, _skip_resize=True) + model.eval() + model.rpn.score_thresh = 0.0 + model.rpn._pre_nms_top_n = {"training": 2000, "testing": 100} + model.rpn._post_nms_top_n = {"training": 2000, "testing": 100} + model.roi_heads.score_thresh = 0.0 + model.roi_heads.detections_per_img = 20 + + with torch.no_grad(): + _ = model([torch.randn(3, 256, 256)]) + + with torch.no_grad(): + ep = export( + model, + ([torch.randn(3, 256, 320)],), + dynamic_shapes=_fpn_dynamic_shapes(), + strict=strict, + ) + + set_rng_seed(42) + x = torch.rand(3, 320, 512) + with torch.no_grad(): + eager_out = model([x.clone()]) + export_out = ep.module()([x.clone()]) + + n_eager = len(eager_out[0]["boxes"]) + n_export = len(export_out[0]["boxes"]) + assert n_eager > 0, "Expected detections with lowered thresholds" + assert n_export > 0, "Export should also produce detections" + # With random weights, NMS is sensitive to floating-point differences + # so we verify count and structure rather than exact coordinates + assert n_eager == n_export + + @pytest.mark.parametrize("strict", [False, True]) + def test_export_zero_detections_structure(self, strict): + """Exported model should produce correctly-shaped empty tensors when NMS finds nothing.""" + model = fasterrcnn_mobilenet_v3_large_fpn(num_classes=50, weights_backbone=None, _skip_resize=True) + model.eval() + with torch.no_grad(): + _ = model([torch.randn(3, 256, 256)]) + + with torch.no_grad(): + ep = export( + model, + ([torch.randn(3, 256, 320)],), + dynamic_shapes=_fpn_dynamic_shapes(), + strict=strict, + ) + + set_rng_seed(0) + x = torch.rand(3, 384, 512) + with torch.no_grad(): + eager_out = model([x.clone()]) + export_out = ep.module()([x.clone()]) + + assert eager_out[0]["boxes"].shape[0] == 0, "Expected 0 eager detections with default thresholds" + assert export_out[0]["boxes"].shape == torch.Size([0, 4]) + assert export_out[0]["scores"].shape == torch.Size([0]) + assert export_out[0]["labels"].shape == torch.Size([0]) + + @pytest.mark.parametrize("strict", [False, True]) + def test_export_static_shapes(self, fasterrcnn_model, strict): + """Export with fully static shapes should also work.""" + with torch.no_grad(): + ep = export( + fasterrcnn_model, + ([torch.randn(3, 300, 300)],), + strict=strict, + ) + + set_rng_seed(0) + x = torch.rand(3, 300, 300) + with torch.no_grad(): + eager_out = fasterrcnn_model([x.clone()]) + export_out = ep.module()([x.clone()]) + + assert len(eager_out[0]["boxes"]) == len(export_out[0]["boxes"]) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index 69703ab5817..630b4f1d629 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -81,5 +81,110 @@ def test_not_float_normalize(self): out = transform(image, targets) # noqa: F841 +class TestModelsDetectionUtilsExport: + """Export tests for detection utility components.""" + + @pytest.mark.parametrize("strict", [False, True]) + def test_box_coder_decode_export(self, strict): + """Exported BoxCoder.decode should match eager, using the same pattern + as test_box_linear_coder.""" + from torch.export import export + + class BoxCoderDecodeModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.box_coder = _utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + def forward(self, rel_codes, boxes): + return self.box_coder.decode(rel_codes, [boxes]) + + torch.manual_seed(0) + boxes = torch.rand(10, 4) * 50 + boxes[:, 2:] += boxes[:, :2] + rel_codes = torch.randn(10, 4) + + model = BoxCoderDecodeModule() + with torch.no_grad(): + ep = export(model, (rel_codes, boxes), strict=strict) + + eager_out = model(rel_codes, boxes) + export_out = ep.module()(rel_codes, boxes) + torch.testing.assert_close(eager_out, export_out, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("strict", [False, True]) + def test_box_coder_decode_multi_class_export(self, strict): + """BoxCoder.decode with multi-class box regression (num_classes * 4 columns).""" + from torch.export import export + + num_classes = 5 + + class BoxCoderDecodeModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.box_coder = _utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + def forward(self, rel_codes, boxes): + return self.box_coder.decode(rel_codes, [boxes]) + + torch.manual_seed(0) + boxes = torch.rand(10, 4) * 50 + boxes[:, 2:] += boxes[:, :2] + rel_codes = torch.randn(10, num_classes * 4) + + model = BoxCoderDecodeModule() + with torch.no_grad(): + ep = export(model, (rel_codes, boxes), strict=strict) + + eager_out = model(rel_codes, boxes) + export_out = ep.module()(rel_codes, boxes) + torch.testing.assert_close(eager_out, export_out, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("strict", [False, True]) + def test_transform_batch_images_export(self, strict): + """Exported batch_images should match eager, using the + same inputs as test_transform_copy_targets.""" + from torch.export import Dim, export + + transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) + + class BatchImagesModule(torch.nn.Module): + def __init__(self, t): + super().__init__() + self.size_divisible = t.size_divisible + + def forward(self, image): + return transform.batch_images([image], self.size_divisible) + + model = BatchImagesModule(transform) + model.eval() + + # batch_images pads to stride-32 multiples, creating // guards + # that require constrained dims (32*k aligned) + _h = Dim("_h", min=4, max=25) + _w = Dim("_w", min=4, max=25) + h = 32 * _h + w = 32 * _w + + x = torch.rand(3, 192, 256) # 32-aligned example + with torch.no_grad(): + ep = export( + model, + (x,), + dynamic_shapes={"image": {1: h, 2: w}}, + strict=strict, + ) + + # Same input + eager_out = model(x) + export_out = ep.module()(x) + torch.testing.assert_close(eager_out, export_out, atol=1e-6, rtol=1e-6) + + # Different 32-aligned size + x2 = torch.rand(3, 160, 320) + eager_out2 = model(x2) + export_out2 = ep.module()(x2) + torch.testing.assert_close(eager_out2, export_out2, atol=1e-6, rtol=1e-6) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_ops.py b/test/test_ops.py index 11603df0c4c..3647da5e063 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -945,6 +945,92 @@ def test_batched_nms_implementations(self, seed): empty = torch.empty((0,), dtype=torch.int64) torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None)) + @pytest.mark.parametrize("strict", [False, True]) + @pytest.mark.opcheck_only_one() + def test_nms_export(self, strict): + """Exported nms should match eager.""" + from torch.export import export + + class NMSModule(nn.Module): + def forward(self, boxes, scores, iou_threshold): + return ops.nms(boxes, scores, iou_threshold) + + torch.random.manual_seed(0) + boxes, scores = self._create_tensors_with_iou(100, 0.5) + model = NMSModule() + + with torch.no_grad(): + ep = export(model, (boxes, scores, 0.5), strict=strict) + + eager_out = model(boxes, scores, 0.5) + export_out = ep.module()(boxes, scores, 0.5) + torch.testing.assert_close(eager_out, export_out) + + @pytest.mark.parametrize("strict", [False, True]) + @pytest.mark.opcheck_only_one() + def test_batched_nms_export(self, strict): + """Exported batched_nms should match eager using the same inputs as + test_batched_nms_implementations.""" + from torch.export import export + + class BatchedNMSModule(nn.Module): + def forward(self, boxes, scores, idxs, iou_threshold): + return ops.batched_nms(boxes, scores, idxs, iou_threshold) + + torch.random.manual_seed(0) + num_boxes = 100 + iou_threshold = 0.9 + boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) + scores = torch.rand(num_boxes) + idxs = torch.randint(0, 4, size=(num_boxes,)) + + model = BatchedNMSModule() + with torch.no_grad(): + ep = export(model, (boxes, scores, idxs, iou_threshold), strict=strict) + + eager_out = model(boxes, scores, idxs, iou_threshold) + export_out = ep.module()(boxes, scores, idxs, iou_threshold) + torch.testing.assert_close(eager_out, export_out) + + +class TestMultiScaleRoIAlignExport: + """Export tests for MultiScaleRoIAlign, following the same input setup as test_onnx.py.""" + + @pytest.mark.parametrize("strict", [False, True]) + def test_multiscale_roi_align_export(self, strict): + """Exported MultiScaleRoIAlign should match eager.""" + from collections import OrderedDict + + from torch.export import export + + class TransformModule(nn.Module): + def __init__(self): + super().__init__() + self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2) + self.image_sizes = [(512, 512)] + + def forward(self, feat1, feat2, boxes): + x = OrderedDict([("feat1", feat1), ("feat2", feat2)]) + return self.model(x, [boxes], self.image_sizes) + + torch.random.manual_seed(0) + feat1 = torch.rand(1, 5, 64, 64) + feat2 = torch.rand(1, 5, 16, 16) + boxes = torch.rand(6, 4) * 256 + boxes[:, 2:] += boxes[:, :2] + + model = TransformModule() + # Pre-initialize scales + with torch.no_grad(): + _ = model(feat1, feat2, boxes) + + with torch.no_grad(): + ep = export(model, (feat1, feat2, boxes), strict=strict) + + eager_out = model(feat1, feat2, boxes) + export_out = ep.module()(feat1, feat2, boxes) + torch.testing.assert_close(eager_out, export_out, atol=1e-5, rtol=1e-5) + optests.generate_opcheck_tests( testcase=TestNMS, diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index f75bfb77a7f..4a0f736b874 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -170,7 +170,7 @@ def meta_nms(dets, scores, iou_threshold): lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", ) ctx = torch._custom_ops.get_ctx() - num_to_keep = ctx.create_unbacked_symint() + num_to_keep = ctx.new_dynamic_size(min=0) return dets.new_empty(num_to_keep, dtype=torch.long) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 805c05a92ff..5946cc97c28 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -173,11 +173,20 @@ def decode(self, rel_codes: Tensor, boxes: list[Tensor]) -> Tensor: box_sum = 0 for val in boxes_per_image: box_sum += val - if box_sum > 0: - rel_codes = rel_codes.reshape(box_sum, -1) - pred_boxes = self.decode_single(rel_codes, concat_boxes) - if box_sum > 0: - pred_boxes = pred_boxes.reshape(box_sum, -1, 4) + ncols = rel_codes.shape[-1] + # JIT path uses reshape with -1, which is ambiguous when box_sum is 0. + # Non-JIT path uses explicit dims so reshape works for all box_sum + # (including 0, producing correctly-shaped empty tensors). + if torch.jit.is_scripting(): + if box_sum > 0: + rel_codes = rel_codes.reshape(box_sum, -1) + pred_boxes = self.decode_single(rel_codes, concat_boxes) + if box_sum > 0: + pred_boxes = pred_boxes.reshape(box_sum, -1, 4) + else: + rel_codes = rel_codes.reshape(box_sum, ncols) + pred_boxes = self.decode_single(rel_codes, concat_boxes) + pred_boxes = pred_boxes.reshape(box_sum, ncols // 4, 4) return pred_boxes def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index ac54873dee8..9bb05851ae1 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -1,4 +1,3 @@ -import math from typing import Any, Optional import torch @@ -182,9 +181,9 @@ def resize( target: Optional[dict[str, Tensor]] = None, ) -> tuple[Tensor, Optional[dict[str, Tensor]]]: h, w = image.shape[-2:] + if self._skip_resize: + return image, target if self.training: - if self._skip_resize: - return image, target size = self.torch_choice(self.min_size) else: size = self.min_size[-1] @@ -241,18 +240,16 @@ def batch_images(self, images: list[Tensor], size_divisible: int = 32) -> Tensor return self._onnx_batch_images(images, size_divisible) max_size = self.max_by_axis([list(img.shape) for img in images]) - stride = float(size_divisible) + stride = size_divisible max_size = list(max_size) - max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride) - max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride) - - batch_shape = [len(images)] + max_size - batched_imgs = images[0].new_full(batch_shape, 0) - for i in range(batched_imgs.shape[0]): - img = images[i] - batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + max_size[1] = (max_size[1] + stride - 1) // stride * stride + max_size[2] = (max_size[2] + stride - 1) // stride * stride - return batched_imgs + padded_imgs = [] + for img in images: + padding = [0, max_size[2] - img.shape[2], 0, max_size[1] - img.shape[1]] + padded_imgs.append(torch.nn.functional.pad(img, padding)) + return torch.stack(padded_imgs) def postprocess( self, diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index a089af2c4ad..e2ba0b7f9b7 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -1,6 +1,7 @@ import torch import torchvision from torch import Tensor +from torch.fx.experimental.symbolic_shapes import guard_or_false from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once @@ -77,10 +78,21 @@ def batched_nms( # Benchmarks that drove the following thresholds are at # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339 # and https://github.com/pytorch/vision/pull/8925 - if boxes.numel() > (4000 if boxes.device.type == "cpu" else 100_000) and not torchvision._is_tracing(): - return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold) - else: + if torch.jit.is_scripting(): + # _is_tracing() is always False during scripting, so omitted here + if boxes.numel() > (4000 if boxes.device.type == "cpu" else 100_000): + return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold) return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) + # In eager mode, use the original performance-based dispatch. + # In export (is_compiling), always use coordinate_trick to avoid + # data-dependent branching on boxes.numel(). + if ( + not torch.compiler.is_compiling() + and boxes.numel() > (4000 if boxes.device.type == "cpu" else 100_000) + and not torchvision._is_tracing() + ): + return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold) + return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) @torch.jit._script_if_tracing @@ -94,9 +106,22 @@ def _batched_nms_coordinate_trick( # we add an offset to all the boxes. The offset is dependent # only on the class idx, and is large enough so that boxes # from different classes do not overlap - if boxes.numel() == 0: - return torch.empty((0,), dtype=torch.int64, device=boxes.device) - max_coordinate = boxes.max() + if torch.jit.is_scripting(): + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + max_coordinate = boxes.max() + else: + if guard_or_false(boxes.numel() == 0): + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + if torch.compiler.is_compiling(): + # Concat a zero sentinel so .max() is safe when boxes is empty. + # This only affects max_coordinate; NMS still operates on the + # original (possibly empty) boxes and correctly returns empty. + max_coordinate = torch.cat( + [boxes.reshape(-1), torch.zeros(1, dtype=boxes.dtype, device=boxes.device)] + ).max() + else: + max_coordinate = boxes.max() offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) boxes_for_nms = boxes + offsets[:, None] keep = nms(boxes_for_nms, scores, iou_threshold) diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index f887f6aee33..39e0da56cf3 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -182,7 +182,7 @@ def _multiscale_roi_align( levels = mapper(boxes) - num_rois = len(rois) + num_rois = rois.shape[0] num_channels = x_filtered[0].shape[1] dtype, device = x_filtered[0].dtype, x_filtered[0].device