From e1c5699faa11781d635dd0dc58f62394fa626708 Mon Sep 17 00:00:00 2001 From: Benjamin Li Date: Sun, 25 Jan 2026 12:49:13 -0500 Subject: [PATCH 1/3] initial implementation of object detection --- optimum/commands/export/executorch.py | 8 ++ optimum/executorch/__init__.py | 2 + optimum/executorch/modeling.py | 51 ++++++++++- optimum/exporters/executorch/integrations.py | 63 ++++++++++++++ .../executorch/tasks/object_detection.py | 46 ++++++++++ pyproject.toml | 1 + tests/models/test_modeling_detr.py | 85 +++++++++++++++++++ 7 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 optimum/exporters/executorch/tasks/object_detection.py create mode 100644 tests/models/test_modeling_detr.py diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py index 8df9ed0c..435a2f09 100644 --- a/optimum/commands/export/executorch.py +++ b/optimum/commands/export/executorch.py @@ -189,6 +189,12 @@ def parse_args_executorch(parser): required=False, help="Device to run the model on. Options: cpu, cuda, mps. Default: cpu.", ) + required_group.add_argument( + "--image_size", + type=int, + required=False, + help="Image size for object detection models. Required for object-detection task.", + ) class ExecuTorchExportCommand(BaseOptimumCLICommand): @@ -263,6 +269,8 @@ def run(self): kwargs["dtype"] = self.args.dtype if hasattr(self.args, "device") and self.args.device: kwargs["device"] = self.args.device + if hasattr(self.args, "image_size") and self.args.image_size: + kwargs["image_size"] = self.args.image_size main_export( model_name_or_path=self.args.model, diff --git a/optimum/executorch/__init__.py b/optimum/executorch/__init__.py index 794f0cce..5a62302e 100644 --- a/optimum/executorch/__init__.py +++ b/optimum/executorch/__init__.py @@ -25,6 +25,7 @@ "ExecuTorchModelForSeq2SeqLM", "ExecuTorchModelForSpeechSeq2Seq", "ExecuTorchModelForMultiModalToText", + "ExecuTorchModelForObjectDetection", ], } @@ -34,6 +35,7 @@ ExecuTorchModelForImageClassification, ExecuTorchModelForMaskedLM, ExecuTorchModelForMultiModalToText, + ExecuTorchModelForObjectDetection, ExecuTorchModelForSeq2SeqLM, ExecuTorchModelForSpeechSeq2Seq, ) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index 8e7da4cc..e8a1c58d 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch from huggingface_hub import hf_hub_download, is_offline_mode @@ -31,6 +31,7 @@ AutoModelForCausalLM, AutoModelForImageClassification, AutoModelForMaskedLM, + AutoModelForObjectDetection, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, PreTrainedTokenizer, @@ -929,6 +930,54 @@ def generate(self): raise NotImplementedError +class ExecuTorchModelForObjectDetection(ExecuTorchModelBase): + """ + ExecuTorch model with an object detection head for inference using the ExecuTorch Runtime. + + This class provides an interface for loading, running, and generating outputs from a vision transformer model + optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models + compatible with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForObjectDetection`. + model (`ExecuTorchModule`): + The loaded ExecuTorch model. + """ + + auto_model_class = AutoModelForObjectDetection + + def __init__( + self, + models: Dict[str, "ExecuTorchModule"], + config: "PretrainedConfig", + ): + super().__init__(models, config) + if not hasattr(self, "model"): + raise AttributeError("Expected attribute 'model' not found in the instance.") + metadata = self.model.method_names() + logging.debug(f"Load all static methods: {metadata}") + + def forward( + self, + pixel_values: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the model. + + Args: + pixel_values (`torch.Tensor`): Tensor representing an image input to the model. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Logits and predicted bounding boxes from the model. + """ + outputs = self.model.forward((pixel_values,)) + return outputs[0], outputs[1] # logits, pred_boxes + + def generate(self): + raise NotImplementedError + + class ExecuTorchModelForImageClassification(ExecuTorchModelBase): """ ExecuTorch model with an image classification head for inference using the ExecuTorch Runtime. diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index ce7d6a47..c82b2d91 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -538,6 +538,69 @@ def export( return {"model": exported_program} +class ObjectDetectionExportableModule(torch.nn.Module): + """ + A wrapper module designed to make a object detection model exportable with `torch.export`. + This module ensures that the exported model is compatible with ExecuTorch. + """ + + def __init__(self, model, image_size, num_channels=None): + super().__init__() + self.model = model + self.config = model.config + self.image_size = image_size + self.metadata = save_config_to_constant_methods(model.config, getattr(model, "generation_config", None)) + + num_channels_from_config = self._get_num_channels_from_config() + if num_channels is not None: + self.num_channels = num_channels + elif num_channels_from_config is not None: + self.num_channels = num_channels_from_config + else: + # if nothing else, try 3 (RGB) + self.num_channels = 3 + + def _get_num_channels_from_config(self) -> int | None: + """try various config options to get num_channels, and return None if not found.""" + # unfortunately none of the HF object detection models have consistency in how num_channels is defined + if hasattr(self.config, "num_channels"): + return self.config.num_channels + if hasattr(self.config, "backbone_config") and hasattr(self.config.backbone_config, "num_channels"): + return self.config.backbone_config.num_channels + if hasattr(self.config, "backbone_config") and hasattr(self.config.backbone_config, "in_chans"): + return self.config.backbone_config.in_chans + if hasattr(self.config, "backbone") and hasattr(self.config.backbone, "num_channels"): + return self.config.backbone.num_channels + if hasattr(self.config, "backbone") and hasattr(self.config.backbone, "in_chans"): + return self.config.backbone.in_chans + if hasattr(self.config, "in_chans"): + return self.config.in_chans + return None + + def forward(self, pixel_values): + return self.model(pixel_values=pixel_values) + + def export(self, pixel_values=None) -> Dict[str, ExportedProgram]: + if pixel_values is None: + batch_size = 1 + num_channels = self.num_channels + height = self.image_size + width = self.image_size + pixel_values = torch.rand( + batch_size, num_channels, height, width, dtype=self.model.dtype, device=self.model.device + ) + + with torch.no_grad(): + return { + "model": torch.export.export( + self.model, + args=(), + kwargs={"pixel_values": pixel_values}, + strict=False, + ) + } + + class VisionEncoderExportableModule(torch.nn.Module): """ A wrapper module designed to make a vision encoder-only model exportable with `torch.export`. diff --git a/optimum/exporters/executorch/tasks/object_detection.py b/optimum/exporters/executorch/tasks/object_detection.py new file mode 100644 index 00000000..22e30684 --- /dev/null +++ b/optimum/exporters/executorch/tasks/object_detection.py @@ -0,0 +1,46 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoModelForObjectDetection + +from ..integrations import ObjectDetectionExportableModule +from ..task_registry import register_task + + +# NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py. +# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier. +@register_task("object-detection") +def load_image_classification_model(model_name_or_path: str, **kwargs) -> ObjectDetectionExportableModule: + """ + Loads a vision model for object detection and registers it under the task + 'object-detection' using Hugging Face's `AutoModelForImageClassification`. + + Args: + model_name_or_path (str): + Model ID on huggingface.co or path on disk to the model repository to export. For example: + `model_name_or_path="google/vit-base-patch16-224"` or `mode_name_or_path="/path/to/model_folder` + **kwargs: + Additional configuration options for the model. + + Returns: + ObjectDetectionExportableModule: + An instance of `ObjectDetectionExportableModule` for exporting and lowering to ExecuTorch. + """ + + image_size = kwargs.pop("image_size", None) + if image_size is None: + raise ValueError("image_size is a required argument for object-detection task") + num_channels = kwargs.pop("num_channels", None) + eager_model = AutoModelForObjectDetection.from_pretrained(model_name_or_path, **kwargs).to("cpu").eval() + return ObjectDetectionExportableModule(eager_model, image_size, num_channels) diff --git a/pyproject.toml b/pyproject.toml index a35817bb..e2bcde05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dev = [ "tiktoken", "black~=23.1", "ruff==0.4.4", + "timm", ] [project.urls] diff --git a/tests/models/test_modeling_detr.py b/tests/models/test_modeling_detr.py new file mode 100644 index 00000000..fdb94117 --- /dev/null +++ b/tests/models/test_modeling_detr.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import tempfile +import unittest + +import pytest +import torch +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoConfig, AutoModelForObjectDetection +from transformers.testing_utils import slow + +from optimum.executorch import ExecuTorchModelForObjectDetection + +from ..utils import check_close_recursively + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_detr_export_to_executorch(self): + model_id = "facebook/detr-resnet-50" # note: requires timm + task = "object-detection" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --image_size {640} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + def _helper_detr_object_detection(self, recipe: str, image_size: int): + model_id = "facebook/detr-resnet-50" # note: requires timm + + config = AutoConfig.from_pretrained(model_id) + batch_size = 1 + num_channels = config.num_channels + height = image_size + width = image_size + pixel_values = torch.rand(batch_size, num_channels, height, width) + + # Test fetching and lowering the model to ExecuTorch + et_model = ExecuTorchModelForObjectDetection.from_pretrained( + model_id=model_id, recipe=recipe, image_size=image_size + ) + self.assertIsInstance(et_model, ExecuTorchModelForObjectDetection) + self.assertIsInstance(et_model.model, ExecuTorchModule) + + eager_model = AutoModelForObjectDetection.from_pretrained(model_id).eval().to("cpu") + with torch.no_grad(): + eager_output = eager_model(pixel_values) + et_logits, et_pred_boxes = et_model.forward(pixel_values) + + # Compare with eager outputs + self.assertTrue(check_close_recursively(eager_output.logits, et_logits)) + self.assertTrue(check_close_recursively(eager_output.pred_boxes, et_pred_boxes)) + + @slow + @pytest.mark.run_slow + def test_detr_object_detection(self): + self._helper_detr_object_detection(recipe="xnnpack", image_size=640) + + @slow + @pytest.mark.run_slow + @pytest.mark.portable + def test_detr_object_detection_portable(self): + self._helper_detr_object_detection(recipe="portable", image_size=640) From ad3848517e7a0c51dc32e6fbbbf97b6056c22f3f Mon Sep 17 00:00:00 2001 From: Benjamin Li Date: Sun, 25 Jan 2026 23:13:37 -0500 Subject: [PATCH 2/3] add id2label, num_channels, image_size to model --- optimum/executorch/modeling.py | 11 +++++++++++ optimum/exporters/executorch/integrations.py | 20 ++++++++++++++++++-- tests/models/test_modeling_detr.py | 3 +++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index e8a1c58d..f119d94a 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -956,6 +956,17 @@ def __init__( if not hasattr(self, "model"): raise AttributeError("Expected attribute 'model' not found in the instance.") metadata = self.model.method_names() + + # Reconstruct id2label/label2id dicts from lists + if "get_label_ids" in metadata and "get_label_names" in metadata: + label_ids = self.model.run_method("get_label_ids") + label_names = self.model.run_method("get_label_names") + self.id2label = dict(zip(label_ids, label_names)) + self.label2id = {v: k for k, v in self.id2label.items()} + if "image_size" in metadata: + self.image_size = self.model.run_method("image_size")[0] + if "num_channels" in metadata: + self.num_channels = self.model.run_method("num_channels")[0] logging.debug(f"Load all static methods: {metadata}") def forward( diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index c82b2d91..5a7fb4e5 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Dict @@ -549,8 +548,17 @@ def __init__(self, model, image_size, num_channels=None): self.model = model self.config = model.config self.image_size = image_size - self.metadata = save_config_to_constant_methods(model.config, getattr(model, "generation_config", None)) + # Convert id2label dict into two lists to store properly in pte + id2label = getattr(model.config, "id2label", None) + if id2label: + label_ids = list(id2label.keys()) + label_names = list(id2label.values()) + else: + label_ids = [] + label_names = [] + + # resolve num_channels num_channels_from_config = self._get_num_channels_from_config() if num_channels is not None: self.num_channels = num_channels @@ -559,6 +567,14 @@ def __init__(self, model, image_size, num_channels=None): else: # if nothing else, try 3 (RGB) self.num_channels = 3 + self.metadata = save_config_to_constant_methods( + model.config, + getattr(model, "generation_config", None), + get_label_ids=label_ids, + get_label_names=label_names, + image_size=self.image_size, + num_channels=self.num_channels, + ) def _get_num_channels_from_config(self) -> int | None: """try various config options to get num_channels, and return None if not found.""" diff --git a/tests/models/test_modeling_detr.py b/tests/models/test_modeling_detr.py index fdb94117..3bc325e9 100644 --- a/tests/models/test_modeling_detr.py +++ b/tests/models/test_modeling_detr.py @@ -63,6 +63,9 @@ def _helper_detr_object_detection(self, recipe: str, image_size: int): ) self.assertIsInstance(et_model, ExecuTorchModelForObjectDetection) self.assertIsInstance(et_model.model, ExecuTorchModule) + self.assertIsInstance(et_model.id2label, dict) + self.assertEqual(et_model.image_size, image_size) + self.assertEqual(et_model.num_channels, num_channels) eager_model = AutoModelForObjectDetection.from_pretrained(model_id).eval().to("cpu") with torch.no_grad(): From e823121681b2896f53f77de0e7c4d226a2ca1567 Mon Sep 17 00:00:00 2001 From: Benjamin Li Date: Sun, 25 Jan 2026 23:23:54 -0500 Subject: [PATCH 3/3] change function name --- optimum/exporters/executorch/integrations.py | 1 + optimum/exporters/executorch/tasks/object_detection.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 5a7fb4e5..27239349 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging from typing import Dict diff --git a/optimum/exporters/executorch/tasks/object_detection.py b/optimum/exporters/executorch/tasks/object_detection.py index 22e30684..66e46d80 100644 --- a/optimum/exporters/executorch/tasks/object_detection.py +++ b/optimum/exporters/executorch/tasks/object_detection.py @@ -21,7 +21,7 @@ # NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py. # This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier. @register_task("object-detection") -def load_image_classification_model(model_name_or_path: str, **kwargs) -> ObjectDetectionExportableModule: +def load_object_detection_model(model_name_or_path: str, **kwargs) -> ObjectDetectionExportableModule: """ Loads a vision model for object detection and registers it under the task 'object-detection' using Hugging Face's `AutoModelForImageClassification`.