
DPO padding-free#5141

Draft
qgallouedec wants to merge 3 commits into main from dpo-padding-free

Conversation


@qgallouedec qgallouedec commented Feb 21, 2026

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.



@qgallouedec qgallouedec left a comment


Devin Review found 5 potential issues.

View 1 additional finding in Devin Review.


Comment on lines +1211 to 1215
if self.padding_free:
    shift_logits = restore_padding_from_flattened(shift_logits, flat_position_ids)
shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()
per_token_logps = selective_log_softmax(shift_logits, shift_labels)


🔴 Dimension mismatch between restored logits and shift_labels when padding_free=True with pad_to_multiple_of

When both padding_free=True and pad_to_multiple_of is set, restore_padding_from_flattened produces a tensor whose sequence dimension is max_unpadded_seq_len - 1, while shift_labels and shift_completion_mask (derived from the original padded input_ids) have sequence dimension padded_seq_len - 1. Because pad_to_multiple_of can make padded_seq_len > max_unpadded_seq_len, these dimensions will not match, causing a runtime crash in selective_log_softmax (or the liger loss).

Detailed explanation with concrete example

Consider a batch with 2 sequences of lengths 3 and 2, and pad_to_multiple_of=4:

  • DataCollator pads to length 4 → input_ids shape: (2, 4)
  • shift_labels = input_ids[..., 1:] → shape (2, 3)
  • Flattening removes padding → 5 non-padding tokens
  • Model forward → (1, 5, vocab); after [..., :-1, :] → (1, 4, vocab)
  • restore_padding_from_flattened removes cross-boundary elements and pads back → shape (2, 2, vocab) (max unpadded length 3 minus 1 = 2)
  • selective_log_softmax(shift_logits, shift_labels) receives shapes (2, 2, vocab) and (2, 3) → crash

This affects three methods: _compute_loss at trl/trainer/dpo_trainer.py:1215, compute_ref_log_probs at trl/trainer/dpo_trainer.py:1080, and _compute_loss_liger at trl/trainer/dpo_trainer.py:1146-1147. There is no validation preventing the user from setting both padding_free=True and pad_to_multiple_of.

Impact: Any user combining padding_free=True with a non-None pad_to_multiple_of will hit a dimension mismatch crash during training whenever pad_to_multiple_of rounds up beyond the longest sequence in the batch.

Prompt for agents
In trl/trainer/dpo_trainer.py, around where the padding_free and vision dataset checks are done (near line 635-638), add a validation that raises an error when both padding_free=True and pad_to_multiple_of is not None, since restore_padding_from_flattened reconstructs tensors based on actual (unpadded) sequence lengths, which will be shorter than the padded_seq_len when pad_to_multiple_of rounds up. For example, add after line 638:

if self.padding_free and args.pad_to_multiple_of is not None:
    raise ValueError(
        "Padding-free training is not compatible with `pad_to_multiple_of`. "
        "Please set `pad_to_multiple_of=None` when using `padding_free=True`."
    )

Alternatively, the fix could be applied inside the loss computation methods (_compute_loss, compute_ref_log_probs, _compute_loss_liger) by truncating shift_labels and shift_completion_mask to match the restored logits' sequence length after restore_padding_from_flattened is called.
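A minimal sketch of that alternative fix, assuming right-padding from the collator; `align_shift_tensors` is a hypothetical helper name, not code from this PR:

```python
import torch

def align_shift_tensors(shift_logits, shift_labels, shift_completion_mask):
    # restore_padding_from_flattened pads back only to the longest *unpadded*
    # sequence length, which can be shorter than the collator's padded length
    # when pad_to_multiple_of rounds up. With right-padding, truncating the
    # label-side tensors to the restored length drops only padding positions.
    seq_len = shift_logits.size(1)
    return shift_labels[:, :seq_len], shift_completion_mask[:, :seq_len]

# Shapes from the example above: restored logits (2, 2, vocab) vs labels (2, 3)
shift_logits = torch.zeros(2, 2, 8)
shift_labels = torch.arange(6).view(2, 3)
shift_completion_mask = torch.ones(2, 3)
shift_labels, shift_completion_mask = align_shift_tensors(
    shift_logits, shift_labels, shift_completion_mask
)
print(shift_labels.shape)  # torch.Size([2, 2])
```

This only works if padding sits at the end of each row; a collator that left-pads would need masking rather than truncation.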

Comment on lines +142 to +147
keep_mask = flat_position_ids[:, 1:].ne(0).squeeze(0)
tensor = tensor.squeeze(0)[keep_mask]
starts = flat_position_ids.squeeze(0).eq(0).nonzero(as_tuple=True)[0]
ends = torch.cat((starts[1:], starts.new_tensor([flat_position_ids.size(1)])))
split_lengths = (ends - starts - 1).clamp_min(0).tolist()
return pad(list(tensor.split(split_lengths, dim=0)), padding_value=padding_value)

📝 Info: restore_padding_from_flattened cross-sequence boundary filtering is correct but subtle

The keep_mask at trl/trainer/dpo_trainer.py:142 filters elements from the shifted ([..., :-1, :]) tensor using flat_position_ids[:, 1:].ne(0). This removes exactly the cross-sequence boundary predictions (where a token from sequence N predicts the first token of sequence N+1). For example, with flat_position_ids = [[0,1,2,3,0,1,0,1,2]], the shifted position IDs are [1,2,3,0,1,0,1,2], and .ne(0) correctly identifies indices 3 and 5 as boundaries to discard. The split_lengths at line 146 use (ends - starts - 1).clamp_min(0) which correctly yields per-sequence shifted lengths (original_length - 1). This logic is correct but non-obvious — a comment explaining why keep_mask filters cross-boundary predictions would improve maintainability.
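The boundary filtering can be traced with a standalone sketch; `split_lengths_for_shifted` is an illustrative stand-in mirroring the logic described above, not the TRL function itself:

```python
import torch

def split_lengths_for_shifted(flat_position_ids: torch.Tensor):
    """Return the keep mask over shifted predictions and per-sequence shifted lengths."""
    # A position id of 0 marks the start of a packed sequence, so after the
    # [..., :-1, :] shift, any prediction whose *target* position id is 0 is a
    # cross-boundary prediction (last token of seq N predicting the first token
    # of seq N+1) and must be discarded.
    keep_mask = flat_position_ids[:, 1:].ne(0).squeeze(0)
    starts = flat_position_ids.squeeze(0).eq(0).nonzero(as_tuple=True)[0]
    ends = torch.cat((starts[1:], starts.new_tensor([flat_position_ids.size(1)])))
    # Each packed sequence of length L contributes L - 1 shifted predictions.
    split_lengths = (ends - starts - 1).clamp_min(0).tolist()
    return keep_mask, split_lengths

# Worked example from the comment above: packed lengths 4, 2, and 3.
pos = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])
mask, lengths = split_lengths_for_shifted(pos)
# Shifted position ids are [1, 2, 3, 0, 1, 0, 1, 2]; indices 3 and 5 are the
# cross-boundary predictions that get dropped.
print(mask.tolist())  # [True, True, True, False, True, False, True, True]
print(lengths)        # [3, 1, 2]
```

Note the invariant that makes the split valid: the number of kept elements always equals the sum of the shifted lengths.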


Comment on lines +1211 to 1215
if self.padding_free:
    shift_logits = restore_padding_from_flattened(shift_logits, flat_position_ids)
shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()
per_token_logps = selective_log_softmax(shift_logits, shift_labels)

📝 Info: Padding value of 0 in restored logits is safe due to downstream masking

In _compute_loss at line 1212, restore_padding_from_flattened(shift_logits, flat_position_ids) uses the default padding_value=0, meaning restored padding positions get all-zero logits. While this produces incorrect log-probabilities at those positions (approximately -log(vocab_size) after log_softmax), the subsequent masking at line 1216 (per_token_logps[shift_completion_mask == 0] = 0.0) zeros them out before summation. Similarly in compute_ref_log_probs at line 1081. So the choice of padding_value=0 is safe, though it could be confusing to future readers.

(Refers to lines 1211-1216)
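The claim is easy to verify numerically; this is an illustrative check, not code from the PR:

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 8
# A restored padding position: all-zero logits yield a uniform distribution,
# so every per-token log-prob is exactly -log(vocab_size).
zero_logits = torch.zeros(1, 1, vocab_size)
logps = F.log_softmax(zero_logits, dim=-1)
print(torch.allclose(logps, torch.full_like(logps, -math.log(vocab_size))))  # True

# The downstream masking mirrors `per_token_logps[shift_completion_mask == 0] = 0.0`:
per_token_logps = logps[..., 0].clone()              # pretend the label id is 0
completion_mask = torch.zeros_like(per_token_logps)  # padding position → mask 0
per_token_logps[completion_mask == 0] = 0.0
print(per_token_logps.item())  # 0.0
```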


Comment on lines +164 to +177
def test_restore_padding_from_flattened(self):
    flat_position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])
    flattened = torch.arange(16, dtype=torch.float32).view(1, 8, 2)

    restored = restore_padding_from_flattened(flattened, flat_position_ids, padding_value=-100)

    expected = torch.tensor(
        [
            [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]],
            [[8.0, 9.0], [-100.0, -100.0], [-100.0, -100.0]],
            [[12.0, 13.0], [14.0, 15.0], [-100.0, -100.0]],
        ]
    )
    torch.testing.assert_close(restored, expected)

📝 Info: Test for restore_padding_from_flattened only covers multi-token sequences

The test at tests/test_dpo_trainer.py:164-177 covers sequences of lengths 4, 2, and 3, but does not test edge cases such as a sequence with exactly 1 token (which would produce 0 shifted tokens and an empty split). While tracing the logic shows clamp_min(0) at trl/trainer/dpo_trainer.py:146 handles this gracefully (producing an empty tensor that gets padded), and such short sequences are unlikely in DPO training, an explicit test would increase confidence. Similarly, there's no test for the case where all sequences have the same length.
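The 1-token edge case could be exercised with a test along these lines; `restore_padding` below is a local stand-in re-implementing the described logic with `pad_sequence` (the real test would call TRL's restore_padding_from_flattened directly), so everything here is an illustrative assumption:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def restore_padding(tensor, flat_position_ids, padding_value=0.0):
    # Mirrors the logic quoted above: drop cross-boundary predictions,
    # split by per-sequence shifted lengths, then pad back to a batch.
    keep_mask = flat_position_ids[:, 1:].ne(0).squeeze(0)
    tensor = tensor.squeeze(0)[keep_mask]
    starts = flat_position_ids.squeeze(0).eq(0).nonzero(as_tuple=True)[0]
    ends = torch.cat((starts[1:], starts.new_tensor([flat_position_ids.size(1)])))
    split_lengths = (ends - starts - 1).clamp_min(0).tolist()
    return pad_sequence(
        list(tensor.split(split_lengths, dim=0)), batch_first=True, padding_value=padding_value
    )

# Two packed sequences of lengths 2 and 1: the 1-token sequence contributes
# zero shifted predictions, exercising the clamp_min(0) / empty-split path.
flat_position_ids = torch.tensor([[0, 1, 0]])
flattened = torch.arange(4, dtype=torch.float32).view(1, 2, 2)  # 2 shifted tokens
restored = restore_padding(flattened, flat_position_ids, padding_value=-100.0)
expected = torch.tensor([[[0.0, 1.0]], [[-100.0, -100.0]]])
torch.testing.assert_close(restored, expected)
```

The 1-token sequence ends up as a row of pure padding, which is what the masking downstream expects.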


Comment on lines +1197 to +1201
if self.padding_free:
    flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)
    model_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False}
else:
    model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}

📝 Info: No attention_mask passed to model in padding_free mode relies on model auto-inferring it from position_ids

When padding_free=True, the model_kwargs at trl/trainer/dpo_trainer.py:1199 include position_ids but omit attention_mask. This relies on the model's flash attention implementation inferring sequence boundaries from position_ids (detecting resets to 0 as new sequence starts). This is the standard convention for HuggingFace flash attention implementations and is correct for the supported flash attention variants listed in FLASH_ATTENTION_VARIANTS. However, if a user bypasses the warning and uses a non-flash attention implementation, the lack of attention_mask would cause incorrect self-attention (all tokens attending to all other tokens across sequence boundaries).
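A toy illustration of why `position_ids` alone carry the sequence boundaries; `flatten_for_padding_free` is a hypothetical stand-in for TRL's helper, assuming a right-padded batch:

```python
import torch

def flatten_for_padding_free(input_ids, attention_mask):
    # Concatenate the non-padding tokens of all rows into one "sequence"
    # and restart the position counter at 0 for each original row.
    lengths = attention_mask.sum(dim=1).tolist()
    flat_ids = input_ids[attention_mask.bool()].unsqueeze(0)
    flat_pos = torch.cat([torch.arange(l) for l in lengths]).unsqueeze(0)
    return flat_ids, flat_pos

# Right-padded batch with sequence lengths 3 and 2
input_ids = torch.tensor([[5, 6, 7], [8, 9, 0]])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
flat_ids, flat_pos = flatten_for_padding_free(input_ids, attention_mask)
print(flat_ids.tolist())  # [[5, 6, 7, 8, 9]]
print(flat_pos.tolist())  # [[0, 1, 2, 0, 1]]
# Flash-attention kernels treat the reset to 0 at index 3 as a new sequence
# start; without that signal (or an attention mask), tokens would attend
# across the boundary.
```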

