8 changes: 8 additions & 0 deletions optimum/commands/export/executorch.py
@@ -189,6 +189,12 @@ def parse_args_executorch(parser):
required=False,
help="Device to run the model on. Options: cpu, cuda, mps. Default: cpu.",
)
required_group.add_argument(
"--image_size",
type=int,
required=False,
help="Image size for object detection models. Required for object-detection task.",
)


class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -263,6 +269,8 @@ def run(self):
kwargs["dtype"] = self.args.dtype
if hasattr(self.args, "device") and self.args.device:
kwargs["device"] = self.args.device
if hasattr(self.args, "image_size") and self.args.image_size:
kwargs["image_size"] = self.args.image_size

main_export(
model_name_or_path=self.args.model,
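For end-to-end context, here is a minimal sketch of the CLI invocation the new flag enables, mirroring the subprocess call in the slow test at the bottom of this diff (the model id, image size, and output directory are illustrative):

```python
# Minimal sketch: export an object-detection model with the new --image_size flag.
# Mirrors tests/models/test_modeling_detr.py below; paths and sizes are illustrative.
import subprocess

subprocess.run(
    "optimum-cli export executorch"
    " --model facebook/detr-resnet-50"
    " --task object-detection"
    " --recipe xnnpack"
    " --image_size 640"
    " --output_dir ./detr_executorch",
    shell=True,
    check=True,
)
```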
2 changes: 2 additions & 0 deletions optimum/executorch/__init__.py
@@ -25,6 +25,7 @@
"ExecuTorchModelForSeq2SeqLM",
"ExecuTorchModelForSpeechSeq2Seq",
"ExecuTorchModelForMultiModalToText",
"ExecuTorchModelForObjectDetection",
],
}

@@ -34,6 +35,7 @@
ExecuTorchModelForImageClassification,
ExecuTorchModelForMaskedLM,
ExecuTorchModelForMultiModalToText,
ExecuTorchModelForObjectDetection,
ExecuTorchModelForSeq2SeqLM,
ExecuTorchModelForSpeechSeq2Seq,
)
62 changes: 61 additions & 1 deletion optimum/executorch/modeling.py
@@ -20,7 +20,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from huggingface_hub import hf_hub_download, is_offline_mode
@@ -31,6 +31,7 @@
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForObjectDetection,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
PreTrainedTokenizer,
@@ -929,6 +930,65 @@ def generate(self):
raise NotImplementedError


class ExecuTorchModelForObjectDetection(ExecuTorchModelBase):
"""
ExecuTorch model with an object detection head for inference using the ExecuTorch Runtime.

This class provides an interface for loading and running an object detection model optimized for the
ExecuTorch Runtime, along with utilities for exporting and loading pre-trained models compatible with
the ExecuTorch runtime.

Attributes:
auto_model_class (`Type`):
Associated Transformers class, `AutoModelForObjectDetection`.
model (`ExecuTorchModule`):
The loaded ExecuTorch model.
"""

auto_model_class = AutoModelForObjectDetection

def __init__(
self,
models: Dict[str, "ExecuTorchModule"],
config: "PretrainedConfig",
):
super().__init__(models, config)
if not hasattr(self, "model"):
raise AttributeError("Expected attribute 'model' not found in the instance.")
metadata = self.model.method_names()

# Reconstruct id2label/label2id dicts from lists
if "get_label_ids" in metadata and "get_label_names" in metadata:
label_ids = self.model.run_method("get_label_ids")
label_names = self.model.run_method("get_label_names")
self.id2label = dict(zip(label_ids, label_names))
self.label2id = {v: k for k, v in self.id2label.items()}
if "image_size" in metadata:
self.image_size = self.model.run_method("image_size")[0]
if "num_channels" in metadata:
self.num_channels = self.model.run_method("num_channels")[0]
logging.debug(f"Load all static methods: {metadata}")

def forward(
self,
pixel_values: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the model.

Args:
pixel_values (`torch.Tensor`): Tensor representing an image input to the model.

Returns:
Tuple[torch.Tensor, torch.Tensor]: Logits and predicted bounding boxes from the model.
"""
outputs = self.model.forward((pixel_values,))
return outputs[0], outputs[1] # logits, pred_boxes

def generate(self):
raise NotImplementedError


class ExecuTorchModelForImageClassification(ExecuTorchModelBase):
"""
ExecuTorch model with an image classification head for inference using the ExecuTorch Runtime.
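A hedged usage sketch for the new class, based on the test helper at the bottom of this diff; the model id and sizes are illustrative, and the shapes in the final comment assume a DETR-style model:

```python
# Minimal sketch: fetch, lower, and run inference via ExecuTorchModelForObjectDetection.
import torch

from optimum.executorch import ExecuTorchModelForObjectDetection

et_model = ExecuTorchModelForObjectDetection.from_pretrained(
    model_id="facebook/detr-resnet-50", recipe="xnnpack", image_size=640
)

# Metadata is reconstructed from the .pte constant methods in __init__ above.
print(et_model.image_size, et_model.num_channels, len(et_model.id2label))

# Forward pass on a dummy image; real inputs would come from the image processor.
pixel_values = torch.rand(1, et_model.num_channels, et_model.image_size, et_model.image_size)
logits, pred_boxes = et_model.forward(pixel_values)
print(logits.shape, pred_boxes.shape)  # e.g. (1, num_queries, num_labels + 1), (1, num_queries, 4)
```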
80 changes: 80 additions & 0 deletions optimum/exporters/executorch/integrations.py
@@ -538,6 +538,86 @@ def export(
return {"model": exported_program}


class ObjectDetectionExportableModule(torch.nn.Module):
"""
A wrapper module designed to make an object detection model exportable with `torch.export`.
This module ensures that the exported model is compatible with ExecuTorch.
"""

def __init__(self, model, image_size, num_channels=None):
super().__init__()
self.model = model
self.config = model.config
self.image_size = image_size

# Convert id2label dict into two lists to store properly in pte
id2label = getattr(model.config, "id2label", None)
if id2label:
label_ids = list(id2label.keys())
label_names = list(id2label.values())
else:
label_ids = []
label_names = []

# resolve num_channels
num_channels_from_config = self._get_num_channels_from_config()
if num_channels is not None:
self.num_channels = num_channels
elif num_channels_from_config is not None:
self.num_channels = num_channels_from_config
else:
# if nothing else, try 3 (RGB)
self.num_channels = 3
self.metadata = save_config_to_constant_methods(
model.config,
getattr(model, "generation_config", None),
get_label_ids=label_ids,
get_label_names=label_names,
image_size=self.image_size,
num_channels=self.num_channels,
)

def _get_num_channels_from_config(self) -> int | None:
"""try various config options to get num_channels, and return None if not found."""
# unfortunately none of the HF object detection models have consistency in how num_channels is defined
if hasattr(self.config, "num_channels"):
return self.config.num_channels
if hasattr(self.config, "backbone_config") and hasattr(self.config.backbone_config, "num_channels"):
return self.config.backbone_config.num_channels
if hasattr(self.config, "backbone_config") and hasattr(self.config.backbone_config, "in_chans"):
return self.config.backbone_config.in_chans
if hasattr(self.config, "backbone") and hasattr(self.config.backbone, "num_channels"):
return self.config.backbone.num_channels
if hasattr(self.config, "backbone") and hasattr(self.config.backbone, "in_chans"):
return self.config.backbone.in_chans
if hasattr(self.config, "in_chans"):
return self.config.in_chans
return None

def forward(self, pixel_values):
return self.model(pixel_values=pixel_values)

def export(self, pixel_values=None) -> Dict[str, ExportedProgram]:
if pixel_values is None:
batch_size = 1
num_channels = self.num_channels
height = self.image_size
width = self.image_size
pixel_values = torch.rand(
batch_size, num_channels, height, width, dtype=self.model.dtype, device=self.model.device
)

with torch.no_grad():
return {
"model": torch.export.export(
self.model,
args=(),
kwargs={"pixel_values": pixel_values},
strict=False,
)
}


class VisionEncoderExportableModule(torch.nn.Module):
"""
A wrapper module designed to make a vision encoder-only model exportable with `torch.export`.
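The wrapper can also be driven directly, which makes the num_channels fallback chain easy to see; a minimal sketch, assuming `transformers` and `timm` are installed (in practice `main_export` builds the wrapper through the task loader below):

```python
# Minimal sketch: construct the wrapper and export without going through the CLI.
from transformers import AutoModelForObjectDetection

from optimum.exporters.executorch.integrations import ObjectDetectionExportableModule

eager = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50").eval()

# num_channels=None exercises the fallback chain: explicit arg -> config lookups -> 3 (RGB).
wrapper = ObjectDetectionExportableModule(eager, image_size=640, num_channels=None)
print(wrapper.num_channels)  # e.g. 3 for facebook/detr-resnet-50

# pixel_values=None triggers the dummy-input path with shape (1, num_channels, 640, 640).
programs = wrapper.export()
print(list(programs.keys()))  # ["model"]
```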
46 changes: 46 additions & 0 deletions optimum/exporters/executorch/tasks/object_detection.py
@@ -0,0 +1,46 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoModelForObjectDetection

from ..integrations import ObjectDetectionExportableModule
from ..task_registry import register_task


# NOTE: It’s important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py.
# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier.
@register_task("object-detection")
def load_object_detection_model(model_name_or_path: str, **kwargs) -> ObjectDetectionExportableModule:
"""
Loads a vision model for object detection and registers it under the task
'object-detection' using Hugging Face's `AutoModelForObjectDetection`.

Args:
model_name_or_path (str):
Model ID on huggingface.co or path on disk to the model repository to export. For example:
`model_name_or_path="facebook/detr-resnet-50"` or `model_name_or_path="/path/to/model_folder"`
**kwargs:
Additional configuration options for the model.

Returns:
ObjectDetectionExportableModule:
An instance of `ObjectDetectionExportableModule` for exporting and lowering to ExecuTorch.
"""

image_size = kwargs.pop("image_size", None)
if image_size is None:
raise ValueError("image_size is a required argument for object-detection task")
num_channels = kwargs.pop("num_channels", None)
eager_model = AutoModelForObjectDetection.from_pretrained(model_name_or_path, **kwargs).to("cpu").eval()
return ObjectDetectionExportableModule(eager_model, image_size, num_channels)
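Because `image_size` is validated here rather than at the CLI layer, programmatic callers must pass it explicitly; a small sketch (the model id is illustrative, and DETR additionally needs `timm`):

```python
# Minimal sketch: resolving the loader directly; image_size is mandatory.
from optimum.exporters.executorch.tasks.object_detection import load_object_detection_model

module = load_object_detection_model("facebook/detr-resnet-50", image_size=640)
print(type(module).__name__)  # ObjectDetectionExportableModule

# Omitting image_size raises ValueError("image_size is a required argument ...").
```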
1 change: 1 addition & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ dev = [
"tiktoken",
"black~=23.1",
"ruff==0.4.4",
"timm",
]

[project.urls]
88 changes: 88 additions & 0 deletions tests/models/test_modeling_detr.py
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import tempfile
import unittest

import pytest
import torch
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from transformers import AutoConfig, AutoModelForObjectDetection
from transformers.testing_utils import slow

from optimum.executorch import ExecuTorchModelForObjectDetection

from ..utils import check_close_recursively


class ExecuTorchModelIntegrationTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@slow
@pytest.mark.run_slow
def test_detr_export_to_executorch(self):
model_id = "facebook/detr-resnet-50" # note: requires timm
task = "object-detection"
recipe = "xnnpack"
with tempfile.TemporaryDirectory() as tempdir:
subprocess.run(
f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --image_size {640} --output_dir {tempdir}/executorch",
shell=True,
check=True,
)
self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))

def _helper_detr_object_detection(self, recipe: str, image_size: int):
model_id = "facebook/detr-resnet-50" # note: requires timm

config = AutoConfig.from_pretrained(model_id)
batch_size = 1
num_channels = config.num_channels
height = image_size
width = image_size
pixel_values = torch.rand(batch_size, num_channels, height, width)

# Test fetching and lowering the model to ExecuTorch
et_model = ExecuTorchModelForObjectDetection.from_pretrained(
model_id=model_id, recipe=recipe, image_size=image_size
)
self.assertIsInstance(et_model, ExecuTorchModelForObjectDetection)
self.assertIsInstance(et_model.model, ExecuTorchModule)
self.assertIsInstance(et_model.id2label, dict)
self.assertEqual(et_model.image_size, image_size)
self.assertEqual(et_model.num_channels, num_channels)

eager_model = AutoModelForObjectDetection.from_pretrained(model_id).eval().to("cpu")
with torch.no_grad():
eager_output = eager_model(pixel_values)
et_logits, et_pred_boxes = et_model.forward(pixel_values)

# Compare with eager outputs
self.assertTrue(check_close_recursively(eager_output.logits, et_logits))
self.assertTrue(check_close_recursively(eager_output.pred_boxes, et_pred_boxes))

@slow
@pytest.mark.run_slow
def test_detr_object_detection(self):
self._helper_detr_object_detection(recipe="xnnpack", image_size=640)

@slow
@pytest.mark.run_slow
@pytest.mark.portable
def test_detr_object_detection_portable(self):
self._helper_detr_object_detection(recipe="portable", image_size=640)