From 718c20848d3ae82a7b84e05d698ab75a14d37955 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 26 Jan 2026 16:16:54 -0800
Subject: [PATCH] Export sampler for Seq2Seq models

For the CUDA backend we want to keep most, if not all, computation on
device. Assuming most ASR applications use greedy argmax sampling, we
export and lower the sampling step as a method of the ExecuTorch model.
This way the runner can choose to run it.
---
 optimum/exporters/executorch/integrations.py | 25 ++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index ce7d6a4..28d306f 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -763,6 +763,15 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
         return logits
 
 
+class ArgmaxExportableModule(torch.nn.Module):
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self.model = model
+
+    def forward(self, logits: torch.FloatTensor):
+        return torch.argmax(logits, dim=-1)
+
+
 class Seq2SeqLMExportableModule(torch.nn.Module):
     def __init__(
         self,
@@ -858,6 +867,17 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi
 
         return exported_decoder
 
+    def _export_sampler(self, logits):
+        sampler = ArgmaxExportableModule(self.model).to(self.model.device).eval()
+        with torch.no_grad():
+            exported_sampler = torch.export.export(
+                sampler,
+                (logits,),
+                dynamic_shapes=None,
+                strict=True,
+            )
+        return exported_sampler
+
     def export(
         self,
         encoder_input_ids=None,
@@ -899,9 +919,14 @@ def export(
             example_cache_position,
         )
 
+        self.exported_sampler = self._export_sampler(
+            torch.randn((1, 1, self.config.vocab_size), dtype=self.model.dtype, device=self.model.device)
+        )
+
         return {
             "encoder": self.exported_encoder,  # Not called "text_encoder" because the encoder could be non-text too, e.g. Whisper.
             "text_decoder": self.exported_decoder,
+            "sampler": self.exported_sampler,
         }
 
     def generate(self, prompt_token_ids, max_new_tokens):
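
Below is a minimal standalone sketch (not part of the diff above) of the technique the new
_export_sampler method relies on: exporting a greedy argmax sampler as its own torch.export
program and checking it against eager argmax. The ArgmaxSampler class name and the
vocab_size value are illustrative placeholders; the patch itself uses the model's
config.vocab_size, dtype, and device for the example input.

import torch

class ArgmaxSampler(torch.nn.Module):
    # Greedy sampling: pick the highest-scoring token id along the vocab dimension.
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.argmax(logits, dim=-1)

vocab_size = 32000  # illustrative placeholder; the patch uses self.config.vocab_size
example_logits = torch.randn(1, 1, vocab_size)

# Export with static shapes, mirroring dynamic_shapes=None in the patch.
exported = torch.export.export(ArgmaxSampler().eval(), (example_logits,), strict=True)

# The exported program should reproduce eager argmax on the same input.
token_ids = exported.module()(example_logits)
assert torch.equal(token_ids, torch.argmax(example_logits, dim=-1))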