diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 64ef23b5..e26071a0 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -52,11 +52,17 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies for ExecuTorch run: | + # Clean up cache to save space + pip cache purge || true + rm -rf ~/.cache/huggingface/hub/* || true + if [ "${{ matrix.executorch-version }}" == "nightly" ]; then python install_dev.py else - pip install '.[dev]' - pip install executorch==${{ matrix.executorch-version }} + # Use CPU-only torch to avoid CUDA dependencies (saves ~5GB) + pip install --no-cache-dir '.[dev]' \ + --extra-index-url https://download.pytorch.org/whl/cpu + pip install --no-cache-dir executorch==${{ matrix.executorch-version }} fi pip list - name: Run tests diff --git a/install_dev.py b/install_dev.py index 7ad9fa13..e0a5a697 100644 --- a/install_dev.py +++ b/install_dev.py @@ -5,7 +5,7 @@ def install_torch_nightly_deps(): """Install torch related dependencies from pinned nightly""" - EXECUTORCH_NIGHTLY_VERSION = "dev20251003" + EXECUTORCH_NIGHTLY_VERSION = "dev20251104" TORCHAO_NIGHTLY_VERSION = "dev20251104" # Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/torch_pin.py#L2 TORCH_NIGHTLY_VERSION = "dev20251104" @@ -15,6 +15,7 @@ def install_torch_nightly_deps(): "-m", "pip", "install", + "--no-cache-dir", # Prevent cached CUDA packages f"executorch==1.1.0.{EXECUTORCH_NIGHTLY_VERSION}", f"torch==2.10.0.{TORCH_NIGHTLY_VERSION}", f"torchvision==0.25.0.{TORCH_NIGHTLY_VERSION}", @@ -34,7 +35,7 @@ def install_dep_from_source(): "-m", "pip", "install", - "git+https://github.com/huggingface/transformers@91393fe4cc3266a05bc0d129e34ff5f761bb46e2#egg=transformers", # 4.56.1 + "git+https://github.com/huggingface/transformers@cbc6716945cff1d8e124d344ba0150e6e27f8b6e#egg=transformers", # v5.0.0rc0 ] ) subprocess.check_call( @@ -58,13 +59,13 @@ def main(): ) args = parser.parse_args() - # Install package with dev extras - subprocess.check_call([sys.executable, "-m", "pip", "install", ".[dev]"]) - - # Install nightly dependencies + # Install nightly torch dependencies FIRST to avoid pulling CUDA versions if not args.skip_override_torch: install_torch_nightly_deps() + # Install package with dev extras + subprocess.check_call([sys.executable, "-m", "pip", "install", ".[dev]"]) + # Install source dependencies install_dep_from_source() diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py index 729f6c7f..9b7f2824 100644 --- a/optimum/commands/export/executorch.py +++ b/optimum/commands/export/executorch.py @@ -17,7 +17,8 @@ from pathlib import Path from typing import TYPE_CHECKING -from ...exporters import TasksManager +from transformers.pipelines import get_supported_tasks + from ..base import BaseOptimumCLICommand, CommandInfo @@ -46,7 +47,7 @@ def parse_args_executorch(parser): default="text-generation", help=( "The task to export the model for. Available tasks depend on the model, but are among:" - f" {str(TasksManager.get_all_tasks())}." + f" {str(get_supported_tasks())}." 
), ) required_group.add_argument( diff --git a/optimum/commands/register/register_export.py b/optimum/commands/register/register_export.py index 3959de63..bd85f5ec 100644 --- a/optimum/commands/register/register_export.py +++ b/optimum/commands/register/register_export.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..export import ExportCommand -from ..export.executorch import ExecuTorchExportCommand +from optimum.commands.export.base import ExportCommand +from optimum.commands.export.executorch import ExecuTorchExportCommand REGISTER_COMMANDS = [(ExecuTorchExportCommand, ExportCommand)] diff --git a/optimum/executorch/attentions/custom_kv_cache.py b/optimum/executorch/attentions/custom_kv_cache.py index c10a3d2d..cf0d4ad9 100644 --- a/optimum/executorch/attentions/custom_kv_cache.py +++ b/optimum/executorch/attentions/custom_kv_cache.py @@ -45,8 +45,8 @@ def __init__( device=device, dtype=dtype, ) - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + num_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads self.early_initialization( batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device ) diff --git a/optimum/executorch/attentions/custom_sdpa.py b/optimum/executorch/attentions/custom_sdpa.py index 88b6e2d8..f857cd12 100644 --- a/optimum/executorch/attentions/custom_sdpa.py +++ b/optimum/executorch/attentions/custom_sdpa.py @@ -18,12 +18,59 @@ from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa # noqa +def sdpa_mask_passthrough( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Optional[Callable] = None, + attention_mask: Optional[torch.Tensor] = None, + local_size: Optional[int] = None, + allow_is_causal_skip: bool = True, + allow_torch_fix: bool = True, + **kwargs, +) -> Optional[torch.Tensor]: + """ + Pass-through for attention mask creation since the mask is never used: + - For regular attention, the custom sdpa op in causal mode creates its own attention mask + - For sliding window attention, the mask produced by the attention mask API is discarded and re-created inside the attention implementation, since it needs to know about cache internals + + Additionally, there were some vmap export issues with sliding window attention mask creation in Transformers. + + Args: + batch_size (`int`): + The batch size of the input sequence. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + kv_length (`int`): + The size that the key and value states will have during the attention computation. + kv_offset (`int`, optional): + An optional offset to indicate at which first position the key and values states will refer to. + mask_function (`Callable`): + The mask factory function describing the mask pattern. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) + local_size (`int`, optional): + The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` + to try to skip mask creation if possible. 
+ allow_is_causal_skip (`bool`, optional): + Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in + `torch.sdpa` instead. Default to `True`. + allow_torch_fix (`bool`, optional): + Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older + versions. We need an arg to skip it when using eager. By default `True`. + + """ + return None + + def custom_sdpa_with_start_pos_forward( module: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Union[torch.Tensor, "BlockMask"], # noqa + position_ids: Optional[torch.Tensor] = None, scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, @@ -56,10 +103,10 @@ def custom_sdpa_with_start_pos_forward( # Calculate the input pos from attention mask. # Branch out for float vs bool mask # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix." - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) - first_row_mask = attention_mask[0, :] - # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3 - start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1 + assert ( + position_ids is not None + ), "position_ids must be provided to find start position for causal attention" + start_pos = position_ids[0][0].item() else: start_pos = 0 @@ -95,6 +142,7 @@ def _custom_sdpa_for_ring_kv_cache( key: torch.Tensor, value: torch.Tensor, attention_mask: Union[torch.Tensor, "BlockMask"], # noqa + position_ids: Optional[torch.Tensor] = None, scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, @@ -122,6 +170,7 @@ def _custom_sdpa_for_ring_kv_cache( key, value, attention_mask, + position_ids, scaling, softcap, head_mask, @@ -134,6 +183,7 @@ def _custom_sdpa_for_ring_kv_cache( key, value, attention_mask, + position_ids, scaling, softcap, head_mask, diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index f48d8018..41faa036 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -34,9 +34,9 @@ AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, PreTrainedTokenizer, - add_start_docstrings, ) from transformers.configuration_utils import PretrainedConfig +from transformers.pipelines import get_task from transformers.processing_utils import ProcessorMixin from transformers.utils import is_offline_mode @@ -46,13 +46,11 @@ ) from executorch.kernels import quantized # noqa -from ..exporters import TasksManager from ..exporters.executorch import main_export from ..exporters.executorch.utils import ( process_conversation_inputs, verify_eos_tokens_in_pretrained_tokenizer, ) -from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel from ..utils.file_utils import find_files_matching_pattern from .stats import Stats @@ -63,7 +61,7 @@ logger = logging.getLogger(__name__) -class ExecuTorchModelBase(OptimizedModel, ABC): +class ExecuTorchModelBase(ABC): """ ExecuTorch model for inference using the ExecuTorch Runtime. @@ -99,8 +97,6 @@ def __init__( models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig", ): - super().__init__(model=None, config=config) - if self.__class__.auto_model_class is None: raise ValueError( f"Class {self.__class__.__name__} must set auto_model_class. 
" @@ -268,6 +264,7 @@ def _export( cls, model_id: str, recipe: str, + task: Optional[str] = None, config: Optional[PretrainedConfig] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, @@ -278,9 +275,8 @@ def _export( local_files_only: bool = False, **kwargs, ) -> Dict[str, "ExecuTorchModule"]: - task = kwargs.pop("task", None) - inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class) if not task else task - logging.info(f"Inferred task from model class: {inferred_task}") + inferred_task = get_task(model_id) if not task else task + logging.info(f"Using task: {inferred_task}") save_dir = TemporaryDirectory(prefix="executorch_export_") save_dir_path = Path(save_dir.name) @@ -316,7 +312,6 @@ def _save_pretrained(self, save_directory): raise NotImplementedError @classmethod - @add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING) def from_pretrained( cls, model_id: Union[str, Path], @@ -1322,9 +1317,9 @@ def generate( def text_generation( self, - processor: ProcessorMixin, - tokenizer: "PreTrainedTokenizer", input_conversation: List[Dict], + processor: Optional[ProcessorMixin] = None, + tokenizer: Optional["PreTrainedTokenizer"] = None, echo: bool = True, max_seq_len: Optional[int] = None, ): @@ -1362,9 +1357,9 @@ def text_generation( self.stats.on_inference_start() inputs = process_conversation_inputs( + input_conversation, processor, tokenizer, - input_conversation, ) self.stats.on_token_encode_end() diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py index d64dc527..49fb5635 100644 --- a/optimum/exporters/executorch/convert.py +++ b/optimum/exporters/executorch/convert.py @@ -19,8 +19,7 @@ from pathlib import Path from typing import Union -from transformers.integrations.executorch import sdpa_mask_without_vmap -from transformers.masking_utils import AttentionMaskInterface +from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, AttentionMaskInterface from transformers.modeling_utils import AttentionInterface from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward @@ -29,7 +28,7 @@ AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward) -AttentionMaskInterface.register("custom_sdpa", sdpa_mask_without_vmap) +AttentionMaskInterface.register("custom_sdpa", ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]) def export_to_executorch( diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 1312b9e0..53a424ae 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import logging from typing import Dict @@ -31,24 +32,43 @@ ) from transformers.integrations.executorch import ( TorchExportableModuleForDecoderOnlyLM, - sdpa_mask_without_vmap, ) from transformers.masking_utils import AttentionMaskInterface from transformers.modeling_utils import AttentionInterface -from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache +from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache, sdpa_mask_passthrough from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods +SUPPORTED_VISION_ENCODER_INPUTS = ["pixel_values", "image_sizes"] + + class VisionExportableModule(torch.nn.Module): def __init__(self, model: torch.nn.Module): super().__init__() self.model = model def prepare_export_inputs(self): + # Check if the "get_image_features" function has any required args that we are unprepared for. + required_args = [] + sig = inspect.signature(self.model.get_image_features) + for name, param in sig.parameters.items(): + if param.default is inspect._empty and name != "kwargs": + required_args.append(name) + + if "pixel_values" not in required_args: + raise AttributeError( + "`pixel_values` is not in the `get_image_features()` API. This is unexpected - please investigate." + ) + unsupported_args = set(required_args) - set(SUPPORTED_VISION_ENCODER_INPUTS) + if unsupported_args: + raise AttributeError(f"The following args are not yet supported, please implement: {unsupported_args}.") + # 1. Get export inputs + input_kwargs = {} model_id = self.model.config.name_or_path + config = AutoConfig.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) sample_conversation_with_image = [ { @@ -61,29 +81,42 @@ ], }, ] - processed_inputs = processor.apply_chat_template( - sample_conversation_with_image, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - return_tensors="pt", - ) + + if config.model_type == "mistral3": + from transformers import MistralCommonBackend + + tokenizer = MistralCommonBackend.from_pretrained(model_id) + processed_inputs = tokenizer.apply_chat_template( + sample_conversation_with_image, return_tensors="pt", return_dict=True + ) + else: + processed_inputs = processor.apply_chat_template( + sample_conversation_with_image, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) if "pixel_values" not in processed_inputs: raise ValueError( f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'pixel_values' key: {processed_inputs}" ) - export_inputs = processed_inputs["pixel_values"].to(dtype=self.model.dtype) + + input_kwargs["input_features"] = processed_inputs["pixel_values"].to(dtype=self.model.dtype) + if "image_sizes" in required_args: + input_kwargs["image_sizes"] = [processed_inputs["pixel_values"].shape[-2:]] # 2. Get export dynamic shapes dynamic_shapes = None # No batching for now. - return export_inputs, dynamic_shapes + return input_kwargs, dynamic_shapes def forward( self, input_features: torch.FloatTensor, + **kwargs, ): - image_embeds = self.model.get_image_features(input_features) + image_embeds = self.model.get_image_features(input_features, **kwargs) if isinstance(image_embeds, list): image_embeds = torch.stack(image_embeds) return image_embeds @@ -96,6 +129,7 @@ def __init__(self, model: torch.nn.Module): def prepare_export_inputs(self): # 1. 
Get export inputs + input_kwargs = {} model_id = self.model.config.name_or_path processor = AutoProcessor.from_pretrained(model_id) config = AutoConfig.from_pretrained(model_id) @@ -135,10 +169,11 @@ def prepare_export_inputs(self): raise ValueError( f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'input_features' key: {processed_inputs}" ) - export_inputs = processed_inputs["input_features"].to(dtype=self.model.dtype) + audio_features = processed_inputs["input_features"].to(dtype=self.model.dtype) # Make sure the export inputs has a batch size > 1 so that it doesn't 0/1 specialize. - if export_inputs.shape[0] == 1: - export_inputs = export_inputs.repeat(2, 1, 1) + if audio_features.shape[0] == 1: + audio_features = audio_features.repeat(2, 1, 1) + input_kwargs["input_features"] = audio_features # 2. Get export dynamic shapes # For certain models like Voxtral, each 30 seconds represent one batch. So theoretically this caps @@ -150,15 +185,16 @@ def prepare_export_inputs(self): }, } - return export_inputs, dynamic_shapes + return input_kwargs, dynamic_shapes def forward( self, input_features: torch.FloatTensor, + **kwargs, ): # TODO: remove on next Transformers pin bump. if hasattr(self.model, "get_audio_embeds"): - audio_embeds = self.model.get_audio_embeds(input_features) + audio_embeds = self.model.get_audio_embeds(input_features, **kwargs) else: audio_embeds = self.model.get_audio_features(input_features) return audio_embeds.unsqueeze(0) @@ -212,7 +248,7 @@ def __init__( additional_metadata_kwargs[f"{modality}_token_id"] = getattr(self.config, "image_token_id") self.metadata = save_config_to_constant_methods( config=model.config.text_config, - generation_config=model.generation_config, + generation_config=getattr(model, "generation_config", None), processor_config=processor_config, get_max_seq_len=max_seq_len, **additional_metadata_kwargs, @@ -269,7 +305,7 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module): if self.use_custom_sdpa: if self.use_custom_kv_cache: AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) - AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) + AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_passthrough) # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" @@ -381,20 +417,20 @@ def export( raise ValueError( f"{self.model.config.name_or_path} has an unsupported modality that is not supported yet for Optimum - please file an issue." 
) - input_features, dynamic_shapes = encoder.prepare_export_inputs() + input_kwargs, dynamic_shapes = encoder.prepare_export_inputs() logging.info( - f"Exporting {self.modality} encoder using input_features({input_features.shape}), dynamic_shapes={dynamic_shapes}" + f"Exporting {self.modality} encoder using input_features({input_kwargs['input_features'].shape}), dynamic_shapes={dynamic_shapes}" ) # Move inputs to the same device as the model - input_features = input_features.to(self.model.device) + for kwarg, value in input_kwargs.items(): + if torch.is_tensor(value): + input_kwargs[kwarg] = value.to(self.model.device) encoder_exported_program = torch.export.export( encoder, args=(), - kwargs={ - "input_features": input_features, - }, + kwargs=input_kwargs, dynamic_shapes=dynamic_shapes, strict=True, ) @@ -425,7 +461,7 @@ def __init__( self.disable_dynamic_shapes = disable_dynamic_shapes self.metadata = save_config_to_constant_methods( model.config, - model.generation_config, + generation_config=getattr(model, "generation_config", None), get_max_seq_len=max_seq_len, enable_dynamic_shape=not self.disable_dynamic_shapes, ) @@ -455,7 +491,7 @@ def _prepare_export_inputs(self): if not self.disable_dynamic_shapes and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache: # Prepare inputs with dynamic shapes - seq_length = 3 # Sequence length > 1 to avoid specialization issues + seq_length = 3 # Sequence length > 1 to avoid specialization issue example_input_ids = torch.zeros((1, seq_length), dtype=torch.long, device=self.model.device) example_cache_position = torch.arange(seq_length, dtype=torch.long, device=self.model.device) max_seq_len = self.metadata.get("get_max_seq_len") @@ -471,7 +507,6 @@ def _prepare_export_inputs(self): return example_input_ids, example_cache_position, dynamic_shapes, strict def _register_custom_attention(self, exportable_module: torch.nn.Module): - from transformers.integrations.executorch import sdpa_mask_without_vmap from transformers.masking_utils import AttentionMaskInterface from transformers.modeling_utils import AttentionInterface @@ -479,7 +514,7 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module): if self.use_custom_kv_cache: _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) - AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap) + AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_passthrough) # Manually set the attention implementation to custom_sdpa_ring_kv_cache # This handles both regular sdpa and one for sliding window/local attention exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" @@ -554,7 +589,7 @@ def __init__(self, model): self.model = model self.config = model.config # Metadata to be recorded in the pte model file - self.metadata = save_config_to_constant_methods(model.config, model.generation_config) + self.metadata = save_config_to_constant_methods(model.config, getattr(model, "generation_config", None)) def forward(self, pixel_values): print(f"DEBUG: pixel_values: {pixel_values.shape}") @@ -593,7 +628,7 @@ def __init__(self, model): self.model = model self.config = model.config # Metadata to be recorded in the pte model file - self.metadata = save_config_to_constant_methods(model.config, model.generation_config) + self.metadata = save_config_to_constant_methods(model.config, getattr(model, "generation_config", 
None)) def forward(self, input_ids, attention_mask): return self.model(input_ids, attention_mask) diff --git a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py index d48f8290..42559c68 100644 --- a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py +++ b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py @@ -181,7 +181,16 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs): "device": device, }, ) - decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model) + if hasattr(eager_model, "model"): + # For cases where the actual model containing decoder and encoders is in the top level "model" attribute. + decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model.model) + setattr(eager_model, decoder_name, getattr(eager_model.model, decoder_name)) + if audio_encoder_name: + setattr(eager_model, audio_encoder_name, getattr(eager_model.model, audio_encoder_name)) + if vision_encoder_name: + setattr(eager_model, vision_encoder_name, getattr(eager_model.model, vision_encoder_name)) + else: + decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model) encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name # Need to do this since apparently when nested modules (e.g. model.language_model) access the .property diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py index 8b929c9f..cd6df11b 100644 --- a/optimum/exporters/executorch/utils.py +++ b/optimum/exporters/executorch/utils.py @@ -21,7 +21,7 @@ import transformers from transformers import GenerationConfig, PretrainedConfig from transformers.processing_utils import ProcessorMixin -from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_tokenizers import TokenizersBackend def save_config_to_constant_methods( @@ -88,7 +88,7 @@ def apply_chat_template_with_fallback(processor, conversation, **kwargs): return processor.apply_chat_template(conversation) -def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: PreTrainedTokenizer) -> bool: +def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: TokenizersBackend) -> bool: """ Verifies that the model's EOS token IDs are present in the tokenizer's set of potential end-of-sequence tokens. @@ -135,9 +135,9 @@ def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenize def process_conversation_inputs( - processor: ProcessorMixin, - tokenizer: PreTrainedTokenizer, input_conversation: List[Dict[str, Any]], + processor: Optional[ProcessorMixin] = None, + tokenizer: Optional[TokenizersBackend] = None, ): """ Process input conversation for multimodal models. @@ -154,6 +154,12 @@ def process_conversation_inputs( Returns: Processed inputs ready for model consumption """ + if not tokenizer: + raise ValueError("Must provide tokenizer to process model inputs.") + if not processor: + # Some models don't use a processor, usually in this case the tokenizer does all of the work. 
+ return tokenizer.apply_chat_template(input_conversation, return_tensors="pt", return_dict=True) + if isinstance(processor, transformers.models.granite_speech.processing_granite_speech.GraniteSpeechProcessor): import requests import torchaudio diff --git a/pyproject.toml b/pyproject.toml index d83f191b..6272c40c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,9 @@ classifiers = [ ] dependencies = [ - "optimum~=1.24", + "optimum~=2.0.0", "executorch>=1.0.0", - "transformers==4.56.1", + "transformers==5.0.0rc0", "pytorch-tokenizers>=1.0.1", "accelerate>=0.26.0", ] diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 4f30dc03..63017df1 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -55,7 +55,7 @@ def test_export_cli_helps_no_raise(self): def test_load_cached_model_from_hub(self): model_id = "optimum-internal-testing/tiny-random-llama" - model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack") + model = ExecuTorchModelForCausalLM.from_pretrained(model_id, task="text-generation", recipe="xnnpack") self.assertIsInstance(model, ExecuTorchModelForCausalLM) self.assertTrue(hasattr(model, "model")) self.assertIsInstance(model.model, ExecuTorchModule) diff --git a/tests/models/test_modeling_gemma.py b/tests/models/test_modeling_gemma.py index ccb00024..0212fa16 100644 --- a/tests/models/test_modeling_gemma.py +++ b/tests/models/test_modeling_gemma.py @@ -79,6 +79,7 @@ def test_gemma_text_generation_with_custom_sdpa_8da4w_8we(self): kwargs = {"qlinear": "8da4w", "qembedding": "8w"} model = ExecuTorchModelForCausalLM.from_pretrained( model_id, + task="text-generation", recipe="xnnpack", attn_implementation="custom_sdpa", **kwargs, ) @@ -102,7 +103,7 @@ def test_gemma_text_generation_portable(self): # TODO: Switch to use google/gemma-2b once https://github.com/huggingface/optimum/issues/2127 is fixed # model_id = "google/gemma-2b" model_id = "weqweasdas/RM-Gemma-2B" - model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="portable") + model = ExecuTorchModelForCausalLM.from_pretrained(model_id, task="text-generation", recipe="portable") self.assertIsInstance(model, ExecuTorchModelForCausalLM) self.assertIsInstance(model.model, ExecuTorchModule) diff --git a/tests/models/test_modeling_gemma3.py b/tests/models/test_modeling_gemma3.py index 83dc72fa..fe545b07 100644 --- a/tests/models/test_modeling_gemma3.py +++ b/tests/models/test_modeling_gemma3.py @@ -331,9 +331,9 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self): # Generate generated_text = model.text_generation( + input_conversation=conversation, processor=processor, tokenizer=tokenizer, - input_conversation=conversation, max_seq_len=64, ) logging.info(f"\nGenerated text:\n\t{generated_text}") diff --git a/tests/models/test_modeling_glm.py b/tests/models/test_modeling_glm.py index 8eccd059..18159f56 100644 --- a/tests/models/test_modeling_glm.py +++ b/tests/models/test_modeling_glm.py @@ -49,6 +49,7 @@ def test_glm_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self): tokenizer = AutoTokenizer.from_pretrained(model_id) model = ExecuTorchModelForCausalLM.from_pretrained( model_id, + task="text-generation", recipe="xnnpack", attn_implementation="custom_sdpa", use_custom_kv_cache=True, diff --git a/tests/models/test_modeling_granite_speech.py b/tests/models/test_modeling_granite_speech.py index 51854fa3..34d0b9e2 100644 --- a/tests/models/test_modeling_granite_speech.py +++ 
b/tests/models/test_modeling_granite_speech.py @@ -75,9 +75,9 @@ def test_granite_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8 self.assertIsInstance(model.model, ExecuTorchModule) generated_text = model.text_generation( + input_conversation=conversation, processor=processor, tokenizer=tokenizer, - input_conversation=conversation, max_seq_len=64, ) logging.info(f"\nGenerated text:\n\t{generated_text}") diff --git a/tests/models/test_modeling_ministral.py b/tests/models/test_modeling_ministral.py new file mode 100644 index 00000000..6a3c4652 --- /dev/null +++ b/tests/models/test_modeling_ministral.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import unittest + +import pytest +from transformers import MistralCommonBackend +from transformers.testing_utils import slow + +from optimum.executorch import ExecuTorchModelForMultiModalToText + +from ..utils import check_causal_lm_output_quality + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_ministral_3_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self): + model_id = "mistralai/Ministral-3-3B-Instruct-2512" + tokenizer = MistralCommonBackend.from_pretrained(model_id) + image_url = ( + "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438" + ) + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What action do you think I should take in this situation? 
List all the possible actions and explain why you think they are good or bad.", }, {"type": "image_url", "image_url": {"url": image_url}}, ], }, ] + tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True) + + model = ExecuTorchModelForMultiModalToText.from_pretrained( + model_id, + recipe="xnnpack", + task="multimodal-text-to-text", + use_custom_sdpa=True, + use_custom_kv_cache=True, + qlinear="8da4w", + qlinear_group_size=32, + qlinear_encoder="8da4w", + qlinear_encoder_group_size=32, + qembedding="8w", + qembedding_encoder="8w", + ) + + # Generate + generated_text = model.text_generation( + input_conversation=messages, + tokenizer=tokenizer, + max_seq_len=64, + ) + logging.info(f"\nGenerated text:\n\t{generated_text}") + generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids + + # Free memory before loading eager for quality check + del model + del tokenizer + gc.collect() + + self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens)) diff --git a/tests/models/test_modeling_qwen3_embedding.py b/tests/models/test_modeling_qwen3_embedding.py index 0146634f..43618444 100644 --- a/tests/models/test_modeling_qwen3_embedding.py +++ b/tests/models/test_modeling_qwen3_embedding.py @@ -45,6 +45,7 @@ def test_qwen3_embedding_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we tokenizer = AutoTokenizer.from_pretrained(model_id) model = ExecuTorchModelForCausalLM.from_pretrained( model_id, + task="text-generation", recipe="xnnpack", attn_implementation="custom_sdpa", use_custom_kv_cache=True, @@ -78,7 +79,7 @@ def test_qwen3_embedding_text_generation_portable(self): model_id = "Qwen/Qwen3-Embedding-0.6B" prompt = "Explain gravity" tokenizer = AutoTokenizer.from_pretrained(model_id) - model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="portable") + model = ExecuTorchModelForCausalLM.from_pretrained(model_id, task="text-generation", recipe="portable") self.assertIsInstance(model, ExecuTorchModelForCausalLM) self.assertIsInstance(model.model, ExecuTorchModule) generated_text = model.text_generation( diff --git a/tests/models/test_modeling_t5.py b/tests/models/test_modeling_t5.py index 3ee5c10d..e827e8b8 100644 --- a/tests/models/test_modeling_t5.py +++ b/tests/models/test_modeling_t5.py @@ -49,7 +49,7 @@ def test_t5_export_to_executorch(self): def _helper_t5_translation(self, recipe: str): model_id = "google/flan-t5-small" tokenizer = AutoTokenizer.from_pretrained(model_id) - model = ExecuTorchModelForSeq2SeqLM.from_pretrained(model_id, recipe=recipe) + model = ExecuTorchModelForSeq2SeqLM.from_pretrained(model_id, task="text2text-generation", recipe=recipe) input_text = "translate English to German: How old are you?" generated_text = model.text_generation( @@ -78,7 +78,7 @@ def test_t5_translation_portable(self): def _helper_t5_summarization(self, recipe: str): model_id = "google-t5/t5-small" tokenizer = AutoTokenizer.from_pretrained(model_id) - model = ExecuTorchModelForSeq2SeqLM.from_pretrained(model_id, recipe=recipe) + model = ExecuTorchModelForSeq2SeqLM.from_pretrained(model_id, task="text2text-generation", recipe=recipe) article = ( " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" @@ -110,7 +110,7 @@ def _helper_t5_summarization(self, recipe: str): tokenizer=tokenizer, prompt=article, ) - expected_text = 'a year later, she got married again in westchester county, new york. 
she was married to a different man, but only 18 days after that marriage. she is facing two criminal counts of "offering a false instrument"' + expected_text = 'a year later, she got married again in westchester county, new york . she was married to a different man, but only 18 days after that marriage . she is facing two criminal counts of "offering a false instrument"' logging.info(f"\nInput text:\n\t{article}\nGenerated text:\n\t{generated_text}") self.assertEqual(generated_text, expected_text) diff --git a/tests/models/test_modeling_voxtral.py b/tests/models/test_modeling_voxtral.py index 1f7f6b34..4e639e02 100644 --- a/tests/models/test_modeling_voxtral.py +++ b/tests/models/test_modeling_voxtral.py @@ -310,9 +310,9 @@ def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8 self.assertIsInstance(model.model, ExecuTorchModule) generated_text = model.text_generation( + input_conversation=conversation, processor=processor, tokenizer=tokenizer, - input_conversation=conversation, max_seq_len=64, ) logging.info(f"\nGenerated text:\n\t{generated_text}")