From ca0be5aa5992b0790831feac02c2bcf9a777b062 Mon Sep 17 00:00:00 2001
From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com>
Date: Tue, 10 Feb 2026 03:29:33 +0530
Subject: [PATCH] Fix mRoPE position ID crash when Qwen2-VL prompts are
 truncated

When training Qwen2.5-VL with agent-lightning + verl, prompt truncation
changes the token count but image_grid_thw is computed from the original
(untruncated) image_urls. This causes get_rope_index to fail with a shape
mismatch because it finds fewer image tokens in the truncated input_ids
than entries in image_grid_thw.

After prompt truncation, count remaining image regions in the truncated
token sequence and slice image_urls to match before computing
image_grid_thw, ensuring consistency between the token content and the
mRoPE spatial metadata.

Fixes #441
---
 agentlightning/verl/daemon.py | 49 ++++++++++++++++++++++++++++++++++-
 1 file changed, 48 insertions(+), 1 deletion(-)

diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py
index 98c58f330..ca0fa3599 100644
--- a/agentlightning/verl/daemon.py
+++ b/agentlightning/verl/daemon.py
@@ -310,6 +310,45 @@ def _resolve_image_path(self, path: str) -> str:
             raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.")
         return os.path.join(self.image_base_dir, path)
 
+    def _count_images_in_tokens(self, token_ids: List[int]) -> int:
+        """Count the number of complete image regions in a token ID sequence.
+
+        Image regions are identified by finding ``vision_start_token_id``
+        followed by ``image_token_id``, matching the detection logic used by
+        ``get_rope_index`` in the Qwen2-VL / Qwen2.5-VL model implementation.
+        This is needed to reconcile ``image_grid_thw`` with truncated prompts
+        so that mRoPE position IDs are computed correctly.
+
+        Args:
+            token_ids: List of token IDs (possibly truncated).
+
+        Returns:
+            Number of image regions found in the token sequence, or ``-1`` if
+            the required special-token IDs could not be resolved (in which case
+            the caller should fall back to the original image count).
+        """
+        # Resolve image_token_id from the processor (set during __init__)
+        image_token_id = getattr(self.processor, "image_token_id", None)
+        if image_token_id is None and hasattr(self.tokenizer, "convert_tokens_to_ids"):
+            image_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")
+
+        # Resolve vision_start_token_id -- not stored on the processor, so we
+        # try the tokenizer first and fall back to the well-known default.
+        vision_start_token_id = None
+        if hasattr(self.tokenizer, "convert_tokens_to_ids"):
+            vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
+        if vision_start_token_id is None:
+            vision_start_token_id = 151652  # Qwen2-VL / Qwen2.5-VL default
+
+        if image_token_id is None:
+            return -1
+
+        count = 0
+        for i in range(len(token_ids) - 1):
+            if token_ids[i] == vision_start_token_id and token_ids[i + 1] == image_token_id:
+                count += 1
+        return count
+
     def _get_image_grid_thw(self, image_urls: List[str]) -> Optional[torch.Tensor]:
         """Compute image_grid_thw from image URLs for M-RoPE computation.
 
@@ -907,9 +946,17 @@ def get_train_data_batch(
                 rollout_id_list.append(rollout_id)
                 turn_index_list.append(turn_index)
 
-                # Compute image_grid_thw for this triplet using image_urls from prompt
+                # Compute image_grid_thw for this triplet using image_urls from prompt.
+                # After prompt truncation, some image tokens may have been removed,
+                # so we must reconcile image_urls with the actual images remaining
+                # in the (possibly truncated) prompt to avoid shape mismatches in
+                # get_rope_index when computing mRoPE position IDs.
                 if self._use_mrope:
                     image_urls = trace.get("image_urls", [])
+                    if image_urls:
+                        n_images_in_tokens = self._count_images_in_tokens(prompt_ids)
+                        if n_images_in_tokens >= 0 and n_images_in_tokens < len(image_urls):
+                            image_urls = image_urls[:n_images_in_tokens]
                     image_grid_thw_list.append(self._get_image_grid_thw(image_urls))
                 elif self.trace_aggregator.get("level", "transition") == "trajectory":
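
-- 
Reviewer note (not part of the patch): below is a minimal standalone sketch
of the reconciliation logic, runnable on its own, for sanity-checking the
counting behavior. The <|vision_start|> ID 151652 matches the fallback used
in the patch; 151655 for <|image_pad|> is an assumed default here, since the
patched code resolves it from the processor or tokenizer at runtime.

from typing import List

VISION_START_TOKEN_ID = 151652  # <|vision_start|> (fallback used in the patch)
IMAGE_TOKEN_ID = 151655         # <|image_pad|> (assumed Qwen2-VL default)

def count_images_in_tokens(token_ids: List[int]) -> int:
    """Count <|vision_start|><|image_pad|> pairs, mirroring get_rope_index."""
    return sum(
        1
        for i in range(len(token_ids) - 1)
        if token_ids[i] == VISION_START_TOKEN_ID and token_ids[i + 1] == IMAGE_TOKEN_ID
    )

# Toy prompt with two image regions; truncation drops the second one.
full_prompt = [1, VISION_START_TOKEN_ID, IMAGE_TOKEN_ID, 2,
               VISION_START_TOKEN_ID, IMAGE_TOKEN_ID, 3]
truncated = full_prompt[:4]

image_urls = ["img_a.png", "img_b.png"]
n_images = count_images_in_tokens(truncated)  # one region survives truncation
if 0 <= n_images < len(image_urls):
    image_urls = image_urls[:n_images]        # keep mRoPE metadata consistent
assert image_urls == ["img_a.png"]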