Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,110 @@ def fake_generate(input_ids, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

# NOTE(review): this block was captured from a diff view that stripped the
# original leading indentation; code tokens are reproduced unchanged.
@pytest.mark.xfail(
condition=Version(transformers.__version__) < Version("5.0.0"),
reason="Tool parsing is not supported in transformers versions below 5.0.0",
strict=True,
)
@require_jmespath
def test_training_with_tools_keeps_masks_aligned_when_retokenization_shortens_completion(self):
# Regression test: when a tool round re-tokenizes prompt+completion+tool and
# the resulting completion span comes out SHORTER than the previously recorded
# completion, tool_mask (and logprobs) must be truncated to match, not padded.
# The wrapped _tool_call_loop below records both lengths so the final assert
# can verify per-sample alignment.
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train[:3]")

# Rewards are irrelevant here; a constant-zero reward keeps training inert.
def reward_func(completions, **kwargs):
return [0.0] * len(completions)

training_args = GRPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=128,
max_steps=1,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
tools=[multiply_tool],
)

# Mutable counter shared with the fake generator so it can behave
# differently on the first (initial) and second (post-tool) call.
turn = {"value": 0}

def fake_generate_single_turn(prompts):
del prompts
turn["value"] += 1
if turn["value"] == 1:
# Initial batch: sample 0 is interpreted as a tool call, others are plain completions.
prompt_ids = [[100, 101], [110, 111], [120, 121]]
completion_ids = [
[900, 901, 902, 903],
[700, 701, 702],
[710, 711, 712],
]
return prompt_ids, completion_ids, None, {}
if turn["value"] == 2:
# Tool round for the single tool-calling sample:
# prompt+completion+tool is intentionally short to trigger the edge case where
# completion_tool_length can be shorter than the previous completion length.
prompt_ids = [[100, 101, 201, 202]]
completion_ids = [[600, 601]]
return prompt_ids, completion_ids, None, {}
raise RuntimeError(f"Unexpected fake generation turn: {turn['value']}")

# Token id 900 (first token of sample 0's completion above) is the sentinel
# that marks a completion as a tool call; everything else parses as plain text.
def fake_parse_response(processing_class, ids):
del processing_class
if ids and ids[0] == 900:
return {
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "multiply_tool",
"arguments": {"a": 3, "b": 4},
},
}
],
}
return {"role": "assistant", "content": "ok"}

# Stub out the logps computation: all-zero tensors of the expected shape,
# with requires_grad=True so the training step can still backprop.
def fake_get_per_token_logps_and_entropies(
model,
input_ids,
attention_mask,
logits_to_keep,
*args,
compute_entropy=False,
**kwargs,
):
del model, attention_mask, args, kwargs
shape = (input_ids.size(0), logits_to_keep)
logps = torch.zeros(shape, dtype=torch.float32, device=input_ids.device, requires_grad=True)
entropies = torch.zeros(shape, dtype=torch.float32, device=input_ids.device) if compute_entropy else None
return logps, entropies

# Spy on _tool_call_loop: call the real implementation, then record the
# per-sample lengths of completion_ids vs tool_mask for the final assertion.
observed_lengths = {}
original_tool_call_loop = trainer._tool_call_loop

def wrapped_tool_call_loop(*args, **kwargs):
out = original_tool_call_loop(*args, **kwargs)
tool_mask, _completions, completion_ids, _logprobs, *_ = out
observed_lengths["completion"] = [len(ids) for ids in completion_ids]
observed_lengths["tool"] = [len(mask) for mask in tool_mask]
return out

with (
patch.object(trainer, "_generate_single_turn", side_effect=fake_generate_single_turn),
patch("trl.trainer.grpo_trainer.parse_response", side_effect=fake_parse_response),
patch.object(trainer, "_get_per_token_logps_and_entropies", side_effect=fake_get_per_token_logps_and_entropies),
patch.object(trainer, "_tool_call_loop", side_effect=wrapped_tool_call_loop),
):
trainer.train()

# Per-sample tool_mask length must equal completion_ids length even after
# the shortening tool round above.
assert observed_lengths["completion"] == observed_lengths["tool"]

def test_mismatched_reward_processing_classes_length(self):
"""Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
Expand Down
30 changes: 24 additions & 6 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,9 +1414,16 @@ async def _run_async_tools(async_coros):
prompt_length = len(prompt_ids[idx_with_tool])
ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length]
completion_ids[idx_with_tool] = ct
tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool]))
# Keep tool_mask aligned with completion_ids even if ct becomes shorter.
if len(tool_mask[idx_with_tool]) > len(ct):
tool_mask[idx_with_tool] = tool_mask[idx_with_tool][: len(ct)]
else:
tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool]))
if logprobs is not None:
logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool]))
if len(logprobs[idx_with_tool]) > len(ct):
logprobs[idx_with_tool] = logprobs[idx_with_tool][: len(ct)]
else:
logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool]))
# Keep only non-overlong items for further processing
idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o]
prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o]
Expand Down Expand Up @@ -1459,12 +1466,23 @@ async def _run_async_tools(async_coros):
idx_with_tool = idxs_with_tool[idx]
prompt_completion_tool_length = len(prompt_completion_tool_ids[idx])
prompt_length = len(prompt_ids[idx_with_tool])
completion_length = len(completion_ids[idx_with_tool])
post_tool_length = len(post_tool_ids[idx])
tool_length = prompt_completion_tool_length - prompt_length - completion_length
tool_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length

# Build mask for completion_tool_ids (tool tokens should be 0), then append post-tool model tokens (1).
completion_tool_length = prompt_completion_tool_length - prompt_length
if len(tool_mask[idx_with_tool]) > completion_tool_length:
tool_mask[idx_with_tool] = tool_mask[idx_with_tool][:completion_tool_length]
else:
tool_mask[idx_with_tool] += [0] * (completion_tool_length - len(tool_mask[idx_with_tool]))
tool_mask[idx_with_tool] += [1] * post_tool_length

if logprobs is not None:
logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx]
# Tool-result tokens are external and have no model logprob, so they get 0.0.
if len(logprobs[idx_with_tool]) > completion_tool_length:
logprobs[idx_with_tool] = logprobs[idx_with_tool][:completion_tool_length]
else:
logprobs[idx_with_tool] += [0.0] * (completion_tool_length - len(logprobs[idx_with_tool]))
logprobs[idx_with_tool] += post_tool_logprobs[idx]

# Update completion_ids with the new completions (after tool execution)
for idx in range(len(idxs_with_tool)):
Expand Down