```python
if self.padding_free:
    shift_logits = restore_padding_from_flattened(shift_logits, flat_position_ids)
shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()
per_token_logps = selective_log_softmax(shift_logits, shift_labels)
```
🔴 Dimension mismatch between restored logits and shift_labels when padding_free=True with pad_to_multiple_of
When both padding_free=True and pad_to_multiple_of is set, restore_padding_from_flattened produces a tensor whose sequence dimension is max_unpadded_seq_len - 1, while shift_labels and shift_completion_mask (derived from the original padded input_ids) have sequence dimension padded_seq_len - 1. Because pad_to_multiple_of can make padded_seq_len > max_unpadded_seq_len, these dimensions will not match, causing a runtime crash in selective_log_softmax (or the liger loss).
Detailed explanation with concrete example
Consider a batch with 2 sequences of lengths 3 and 2, and pad_to_multiple_of=4:
- DataCollator pads to length 4 → `input_ids` shape: `(2, 4)`
- `shift_labels = input_ids[..., 1:]` → shape `(2, 3)`
- Flattening removes padding → 5 non-padding tokens
- Model forward → `(1, 5, vocab)`; after `[..., :-1, :]` → `(1, 4, vocab)`
- `restore_padding_from_flattened` removes cross-boundary elements and pads back → shape `(2, 2, vocab)` (max unpadded length 3 minus 1 = 2)
- `selective_log_softmax(shift_logits, shift_labels)` receives shapes `(2, 2, vocab)` and `(2, 3)` → crash
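The length arithmetic behind the mismatch can be reproduced in plain Python. This is a sketch: `padded_len` is a hypothetical helper mirroring what a `pad_to_multiple_of`-aware collator does, not TRL code.

```python
import math

def padded_len(lengths, pad_to_multiple_of):
    # Hypothetical mirror of a pad_to_multiple_of-aware collator's length logic.
    longest = max(lengths)
    if pad_to_multiple_of is None:
        return longest
    return math.ceil(longest / pad_to_multiple_of) * pad_to_multiple_of

lengths = [3, 2]                 # the two sequences from the example
padded = padded_len(lengths, 4)  # 4 -> shift_labels seq dim is padded - 1 = 3
restored = max(lengths) - 1      # restored logits seq dim is 2
print(padded - 1, restored)      # 3 2 -> the two shapes disagree
```

With `pad_to_multiple_of=None`, `padded_len` returns the longest unpadded length and the two dimensions coincide, which is why the bug only surfaces when the rounding kicks in.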
This affects three methods: _compute_loss at trl/trainer/dpo_trainer.py:1215, compute_ref_log_probs at trl/trainer/dpo_trainer.py:1080, and _compute_loss_liger at trl/trainer/dpo_trainer.py:1146-1147. There is no validation preventing the user from setting both padding_free=True and pad_to_multiple_of.
Impact: Any user combining padding_free=True with a non-None pad_to_multiple_of will hit a dimension mismatch crash during training whenever pad_to_multiple_of rounds up beyond the longest sequence in the batch.
Prompt for agents
In trl/trainer/dpo_trainer.py, around where the padding_free and vision dataset checks are done (near line 635-638), add a validation that raises an error when both padding_free=True and pad_to_multiple_of is not None, since restore_padding_from_flattened reconstructs tensors based on actual (unpadded) sequence lengths, which will be shorter than the padded_seq_len when pad_to_multiple_of rounds up. For example, add after line 638:
```python
if self.padding_free and args.pad_to_multiple_of is not None:
    raise ValueError(
        "Padding-free training is not compatible with `pad_to_multiple_of`. "
        "Please set `pad_to_multiple_of=None` when using `padding_free=True`."
    )
```
Alternatively, the fix could be applied inside the loss computation methods (_compute_loss, compute_ref_log_probs, _compute_loss_liger) by truncating shift_labels and shift_completion_mask to match the restored logits' sequence length after restore_padding_from_flattened is called.
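The truncation variant could look roughly like the following. This is a sketch with illustrative shapes taken from the worked example above, not TRL's actual code; the positions cut off are padding-only, since the restored length equals the longest unpadded sequence minus one.

```python
import torch

# Illustrative shapes: batch of 2, padded_seq_len=4, max unpadded length 3, vocab=10.
restored_logits = torch.zeros(2, 2, 10)              # (batch, max_unpadded - 1, vocab)
shift_labels = torch.zeros(2, 3, dtype=torch.long)   # (batch, padded_seq_len - 1)
shift_completion_mask = torch.ones(2, 3, dtype=torch.long)

# Truncate the label-side tensors to the restored logits' sequence length;
# the trailing columns being dropped correspond to padding positions only.
seq_len = restored_logits.size(1)
shift_labels = shift_labels[:, :seq_len]
shift_completion_mask = shift_completion_mask[:, :seq_len]

print(shift_labels.shape, shift_completion_mask.shape)  # both (2, 2) now
```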
```python
keep_mask = flat_position_ids[:, 1:].ne(0).squeeze(0)
tensor = tensor.squeeze(0)[keep_mask]
starts = flat_position_ids.squeeze(0).eq(0).nonzero(as_tuple=True)[0]
ends = torch.cat((starts[1:], starts.new_tensor([flat_position_ids.size(1)])))
split_lengths = (ends - starts - 1).clamp_min(0).tolist()
return pad(list(tensor.split(split_lengths, dim=0)), padding_value=padding_value)
```
📝 Info: restore_padding_from_flattened cross-sequence boundary filtering is correct but subtle
The keep_mask at trl/trainer/dpo_trainer.py:142 filters elements from the shifted ([..., :-1, :]) tensor using flat_position_ids[:, 1:].ne(0). This removes exactly the cross-sequence boundary predictions (where a token from sequence N predicts the first token of sequence N+1). For example, with flat_position_ids = [[0,1,2,3,0,1,0,1,2]], the shifted position IDs are [1,2,3,0,1,0,1,2], and .ne(0) correctly identifies indices 3 and 5 as boundaries to discard. The split_lengths at line 146 use (ends - starts - 1).clamp_min(0) which correctly yields per-sequence shifted lengths (original_length - 1). This logic is correct but non-obvious — a comment explaining why keep_mask filters cross-boundary predictions would improve maintainability.
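The same index arithmetic can be checked in plain Python; this is a standalone rendering of the torch operations above using ordinary lists.

```python
flat_position_ids = [0, 1, 2, 3, 0, 1, 0, 1, 2]

# Shifted positions correspond to flat_position_ids[:, 1:] in the torch code.
shifted = flat_position_ids[1:]            # [1, 2, 3, 0, 1, 0, 1, 2]
keep_mask = [p != 0 for p in shifted]      # False exactly at sequence boundaries
discarded = [i for i, keep in enumerate(keep_mask) if not keep]
print(discarded)                           # [3, 5]

# starts / ends / split_lengths mirror the eq(0) / nonzero / clamp_min logic.
starts = [i for i, p in enumerate(flat_position_ids) if p == 0]  # [0, 4, 6]
ends = starts[1:] + [len(flat_position_ids)]                     # [4, 6, 9]
split_lengths = [max(e - s - 1, 0) for s, e in zip(starts, ends)]
print(split_lengths)                       # [3, 1, 2] == original lengths minus 1
```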
```python
if self.padding_free:
    shift_logits = restore_padding_from_flattened(shift_logits, flat_position_ids)
shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()
per_token_logps = selective_log_softmax(shift_logits, shift_labels)
```
📝 Info: Padding value of 0 in restored logits is safe due to downstream masking
In _compute_loss at line 1212, restore_padding_from_flattened(shift_logits, flat_position_ids) uses the default padding_value=0, meaning restored padding positions get all-zero logits. While this produces incorrect log-probabilities at those positions (approximately -log(vocab_size) after log_softmax), the subsequent masking at line 1216 (per_token_logps[shift_completion_mask == 0] = 0.0) zeros them out before summation. Similarly in compute_ref_log_probs at line 1081. So the choice of padding_value=0 is safe, though it could be confusing to future readers.
(Refers to lines 1211-1216)
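The -log(vocab_size) claim is easy to verify with a toy vocabulary. Plain-Python sketch, not TRL code:

```python
import math

vocab_size = 8
zero_logits = [0.0] * vocab_size

# log_softmax over identical logits gives log(1/V) = -log(V) at every position.
log_denominator = math.log(sum(math.exp(z) for z in zero_logits))
per_token_logp = zero_logits[0] - log_denominator

print(per_token_logp, -math.log(vocab_size))  # both ~ -2.079
```

Downstream, the masking step zeroes these positions before summation, so the bogus value never reaches the loss.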
```python
def test_restore_padding_from_flattened(self):
    flat_position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])
    flattened = torch.arange(16, dtype=torch.float32).view(1, 8, 2)

    restored = restore_padding_from_flattened(flattened, flat_position_ids, padding_value=-100)

    expected = torch.tensor(
        [
            [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]],
            [[8.0, 9.0], [-100.0, -100.0], [-100.0, -100.0]],
            [[12.0, 13.0], [14.0, 15.0], [-100.0, -100.0]],
        ]
    )
    torch.testing.assert_close(restored, expected)
```
📝 Info: Test for restore_padding_from_flattened only covers multi-token sequences
The test at tests/test_dpo_trainer.py:164-177 covers sequences of lengths 4, 2, and 3, but does not test edge cases such as a sequence with exactly 1 token (which would produce 0 shifted tokens and an empty split). While tracing the logic shows clamp_min(0) at trl/trainer/dpo_trainer.py:146 handles this gracefully (producing an empty tensor that gets padded), and such short sequences are unlikely in DPO training, an explicit test would increase confidence. Similarly, there's no test for the case where all sequences have the same length.
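Such an edge-case test might look like the sketch below. The helper is re-implemented here from the diff above so the snippet is self-contained; `torch.nn.utils.rnn.pad_sequence` stands in for TRL's `pad` utility (an assumption about its behavior), and the input tensor represents already-shifted logits.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def restore_padding_from_flattened(tensor, flat_position_ids, padding_value=0):
    # Re-implementation of the helper quoted in the diff, for illustration only;
    # pad_sequence is assumed equivalent to TRL's pad for this case.
    keep_mask = flat_position_ids[:, 1:].ne(0).squeeze(0)
    tensor = tensor.squeeze(0)[keep_mask]
    starts = flat_position_ids.squeeze(0).eq(0).nonzero(as_tuple=True)[0]
    ends = torch.cat((starts[1:], starts.new_tensor([flat_position_ids.size(1)])))
    split_lengths = (ends - starts - 1).clamp_min(0).tolist()
    return pad_sequence(list(tensor.split(split_lengths, dim=0)),
                        batch_first=True, padding_value=padding_value)

# Edge case: first sequence has exactly 1 token (lengths 1 and 3),
# so its shifted length is 0 and its split is empty.
flat_position_ids = torch.tensor([[0, 0, 1, 2]])
flattened = torch.arange(6, dtype=torch.float32).view(1, 3, 2)  # already shifted

restored = restore_padding_from_flattened(flattened, flat_position_ids, padding_value=-100)
print(restored.shape)  # (2, 2, 2); row 0 is all padding (empty shifted sequence)
```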
```python
if self.padding_free:
    flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)
    model_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False}
else:
    model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
```
📝 Info: No attention_mask passed to model in padding_free mode relies on model auto-inferring it from position_ids
When padding_free=True, the model_kwargs at trl/trainer/dpo_trainer.py:1199 include position_ids but omit attention_mask. This relies on the model's flash attention implementation inferring sequence boundaries from position_ids (detecting resets to 0 as new sequence starts). This is the standard convention for HuggingFace flash attention implementations and is correct for the supported flash attention variants listed in FLASH_ATTENTION_VARIANTS. However, if a user bypasses the warning and uses a non-flash attention implementation, the lack of attention_mask would cause incorrect self-attention (all tokens attending to all other tokens across sequence boundaries).
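A hypothetical illustration of how boundaries can be recovered from position_ids alone (plain Python; flash-attention kernels typically consume this as cumulative sequence lengths):

```python
flat_position_ids = [0, 1, 2, 3, 0, 1, 0, 1, 2]

# A reset to 0 marks the start of a new packed sequence.
starts = [i for i, p in enumerate(flat_position_ids) if p == 0]
cu_seqlens = starts + [len(flat_position_ids)]
seq_lengths = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]

print(cu_seqlens)   # [0, 4, 6, 9]
print(seq_lengths)  # [4, 2, 3] -> attention never crosses these boundaries
```

A non-flash implementation that ignores position_ids would see one 9-token sequence here, which is exactly the cross-boundary attention failure the comment warns about.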
What does this PR do?
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.