Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion agentlightning/verl/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,45 @@ def _resolve_image_path(self, path: str) -> str:
raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.")
return os.path.join(self.image_base_dir, path)

def _count_images_in_tokens(self, token_ids: List[int]) -> int:
"""Count the number of complete image regions in a token ID sequence.

Image regions are identified by finding ``vision_start_token_id``
followed by ``image_token_id``, matching the detection logic used by
``get_rope_index`` in the Qwen2-VL / Qwen2.5-VL model implementation.
This is needed to reconcile ``image_grid_thw`` with truncated prompts
so that mRoPE position IDs are computed correctly.

Args:
token_ids: List of token IDs (possibly truncated).

Returns:
Number of image regions found in the token sequence, or ``-1`` if
the required special-token IDs could not be resolved (in which case
the caller should fall back to the original image count).
"""
# Resolve image_token_id from the processor (set during __init__)
image_token_id = getattr(self.processor, "image_token_id", None)
if image_token_id is None and hasattr(self.tokenizer, "convert_tokens_to_ids"):
image_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")

# Resolve vision_start_token_id -- not stored on the processor, so we
# try the tokenizer first and fall back to the well-known default.
vision_start_token_id = None
if hasattr(self.tokenizer, "convert_tokens_to_ids"):
vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
if vision_start_token_id is None:
vision_start_token_id = 151652 # Qwen2-VL / Qwen2.5-VL default

if image_token_id is None:
return -1

count = 0
for i in range(len(token_ids) - 1):
if token_ids[i] == vision_start_token_id and token_ids[i + 1] == image_token_id:
count += 1
return count

def _get_image_grid_thw(self, image_urls: List[str]) -> Optional[torch.Tensor]:
"""Compute image_grid_thw from image URLs for M-RoPE computation.

Expand Down Expand Up @@ -907,9 +946,17 @@ def get_train_data_batch(
rollout_id_list.append(rollout_id)
turn_index_list.append(turn_index)

# Compute image_grid_thw for this triplet using image_urls from prompt
# Compute image_grid_thw for this triplet using image_urls from prompt.
# After prompt truncation, some image tokens may have been removed,
# so we must reconcile image_urls with the actual images remaining
# in the (possibly truncated) prompt to avoid shape mismatches in
# get_rope_index when computing mRoPE position IDs.
if self._use_mrope:
image_urls = trace.get("image_urls", [])
if image_urls:
n_images_in_tokens = self._count_images_in_tokens(prompt_ids)
if n_images_in_tokens >= 0 and n_images_in_tokens < len(image_urls):
image_urls = image_urls[:n_images_in_tokens]
image_grid_thw_list.append(self._get_image_grid_thw(image_urls))

elif self.trace_aggregator.get("level", "transition") == "trajectory":
Expand Down