From 77040b8b74695a26d9932ee79ecfabb491b8bec4 Mon Sep 17 00:00:00 2001
From: Alexander Erben
Date: Fri, 20 Jun 2025 13:55:21 -0400
Subject: [PATCH 1/2] Merged ReferenceRotaryEncoder func into RotaryEncoder

* Added impl flag to choose between the llama and reference implementation
---
 src/fairseq2/nn/_position_encoder.py | 65 ++++++++++++++++++++++++++--
 1 file changed, 62 insertions(+), 3 deletions(-)

diff --git a/src/fairseq2/nn/_position_encoder.py b/src/fairseq2/nn/_position_encoder.py
index b121ae7d2..c3a85e50e 100644
--- a/src/fairseq2/nn/_position_encoder.py
+++ b/src/fairseq2/nn/_position_encoder.py
@@ -311,6 +311,7 @@ class RotaryEncoder(PositionEncoder):
     max_seq_len: int
     theta: float
     freqs_init_fn: Callable[[RotaryEncoder], Tensor] | None
+    impl: str
 
     def __init__(
         self,
@@ -320,6 +321,7 @@ def __init__(
         theta: float = 10_000.0,
         freqs_init_fn: Callable[[RotaryEncoder], Tensor] | None = None,
         device: Device | None = None,
+        impl: str = "llama",
     ) -> None:
         """
         :param encoding_dim: The dimensionality of positional encodings. The
@@ -334,8 +336,14 @@ def __init__(
             expected for the callable to return a :class:`~torch.Tensor` holding
             the frequency table. If ``None``, the frequencies will be
             initialized as described in the reference paper.
+        :param impl: Changes the embedding dimension grouping by using consecutive
+            elements as a real/imag pair ("llama") or using the split-half pairing ("reference").
+            Example: E = 8: [1,2,3,4,5,6,7,8]
+            - "llama": [(1,2), (3,4), (5,6), (7,8)]
+            - "reference": [(1,5), (2,6), (3,7), (4,8)]
 
         :raise ValueError: when ``encoding_dim`` is not even.
+        :raise ValueError: when ``impl`` is not a valid implementation selection.
         """
         super().__init__(encoding_dim)
 
@@ -344,6 +352,12 @@ def __init__(
                 f"`encoding_dim` must be even, but is {encoding_dim} instead."
             )
 
+        if impl not in ["llama", "reference"]:
+            raise ValueError(
+                f"`impl` must be one of [\"llama\", \"reference\"], but is {impl} instead."
+            )
+
+        # (S+1, E / 2, 2)
         freqs = torch.empty(
             (max_seq_len + 1, encoding_dim // 2, 2), device=device, dtype=torch.float32
         )
@@ -356,6 +370,8 @@ def __init__(
 
         self.freqs_init_fn = freqs_init_fn
 
+        self.impl = impl
+
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -430,8 +446,11 @@ def forward(
 
         # (S, E / 2) -> (1, S, E / 2)
         complex_freqs = complex_freqs.unsqueeze(0)
+
+        if self.impl == "reference":
+            seqs = self._split_to_consecutive_layout(tensor=seqs)
 
         # ([N], S, *, E) -> ([N], S, *, E / 2, 2)
         seqs = seqs.unflatten(-1, (-1, 2))
 
         # ([N], S, *, E / 2, 2) -> ([N], S, *, E / 2)
@@ -445,7 +464,47 @@ def forward(
         # ([N], S, *, E / 2) -> ([N], S, *, E)
         fp32_seqs = torch.view_as_real(complex_seqs).flatten(-2)
 
+        if self.impl == "reference":
+            fp32_seqs = self._consecutive_to_split_layout(tensor=fp32_seqs)
+
         return fp32_seqs.type_as(seqs)
+
+    def _consecutive_to_split_layout(self, tensor: Tensor) -> Tensor:
+        """
+        Transforms consecutive pairs to split layout: [1,2,3,4,5,6,7,8] -> [1,3,5,7,2,4,6,8]
+        """
+        original_shape = tensor.shape
+        encoding_dim = original_shape[-1]
+        half_dim = encoding_dim // 2
+
+        # (*, E) -> (*, E / 2, 2)
+        pairs = tensor.view(*original_shape[:-1], half_dim, 2)
+
+        # (*, E / 2)
+        real_parts = pairs[..., 0]
+        # (*, E / 2)
+        imag_parts = pairs[..., 1]
+
+        # (*, E / 2) -> (*, E)
+        return torch.cat([real_parts, imag_parts], dim=-1)
+
+    def _split_to_consecutive_layout(self, tensor: Tensor) -> Tensor:
+        """
+        Transforms split into consecutive layout: [1,3,5,7,2,4,6,8] -> [1,2,3,4,5,6,7,8]
+        """
+        original_shape = tensor.shape
+        encoding_dim = original_shape[-1]
+        half_dim = encoding_dim // 2
+
+        # (*, E) -> (*, E / 2)
+        real_parts = tensor[..., :half_dim]
+        # (*, E) -> (*, E / 2)
+        imag_parts = tensor[..., half_dim:]
+
+        # 2 x (*, E / 2) -> (*, E / 2, 2)
+        pairs = torch.stack([real_parts, imag_parts], dim=-1)
+        # (*, E / 2, 2) -> (*, E)
+        return pairs.view(*original_shape)
 
     @override
     def extra_repr(self) -> str:
@@ -532,10 +591,10 @@ def reset_non_persistent_buffers(self) -> None:
 
         encoding_dim = self.encoding_dim
 
-        # (E)
+        # (E / 2)
         indices = torch.arange(encoding_dim // 2, device=device, dtype=dtype)
 
-        # (E) -> (1, E)
+        # (E / 2) -> (1, E / 2)
         indices = indices.unsqueeze(0)
 
         # (S)
@@ -544,7 +603,7 @@ def reset_non_persistent_buffers(self) -> None:
         # (S, 1)
         steps = steps.unsqueeze(1)
 
-        # (S, 1) x (1, E) -> (S, E)
+        # (S, 1) x (1, E / 2) -> (S, E / 2)
         table = torch.matmul(steps, self.theta ** (-2.0 * indices / encoding_dim))
 
         cos = torch.cos(table)

From eb69a67ca78cdea45688e2515b2491bbf2e8aa00 Mon Sep 17 00:00:00 2001
From: Alexander Erben
Date: Fri, 20 Jun 2025 18:27:01 -0400
Subject: [PATCH 2/2] Merged RotaryEncoder func into ReferenceRotaryEncoder

* Added impl flag to choose between the llama and reference implementation
---
 src/fairseq2/nn/_position_encoder.py | 55 +++++++++++++++++++++++-----
 1 file changed, 45 insertions(+), 10 deletions(-)

diff --git a/src/fairseq2/nn/_position_encoder.py b/src/fairseq2/nn/_position_encoder.py
index c3a85e50e..56cc01bbd 100644
--- a/src/fairseq2/nn/_position_encoder.py
+++ b/src/fairseq2/nn/_position_encoder.py
@@ -336,11 +336,11 @@ def __init__(
             expected for the callable to return a :class:`~torch.Tensor` holding
             the frequency table. If ``None``, the frequencies will be
             initialized as described in the reference paper.
-        :param impl: Changes the embedding dimension grouping by using consecutive
+        :param impl: Changes the embedding dimension ordering by using consecutive
             elements as a real/imag pair ("llama") or using the split-half pairing ("reference").
             Example: E = 8: [1,2,3,4,5,6,7,8]
-            - "llama": [(1,2), (3,4), (5,6), (7,8)]
-            - "reference": [(1,5), (2,6), (3,7), (4,8)]
+            - "llama": [(1,2), (3,4), (5,6), (7,8)] := [real0, imag0, real1, imag1, real2, imag2, real3, imag3]
+            - "reference": [(1,5), (2,6), (3,7), (4,8)] := [real0, real1, real2, real3, imag0, imag1, imag2, imag3]
 
         :raise ValueError: when ``encoding_dim`` is not even.
         :raise ValueError: when ``impl`` is not a valid implementation selection.
@@ -534,6 +534,7 @@ class ReferenceRotaryEncoder(PositionEncoder):
     sin_freqs: Tensor
     max_seq_len: int
     theta: float
+    impl: str
 
     def __init__(
         self,
@@ -542,6 +543,7 @@ def __init__(
         *,
         theta: float = 10_000.0,
         device: Device | None = None,
+        impl: str = "reference",
     ) -> None:
         """
         :param encoding_dim: The dimensionality of positional encodings. The
@@ -551,8 +553,14 @@ def __init__(
             Sequences longer than ``max_seq_len`` will cause a :class:`ValueError`.
         :param theta: The coefficient of the long-term decay as described in
             section 3.3 of the reference paper.
+        :param impl: Changes the embedding dimension ordering by using consecutive
+            elements as a real/imag pair ("llama") or using the split-half pairing ("reference").
+            Example: E = 8: [1,2,3,4,5,6,7,8]
+            - "llama": [(1,2), (3,4), (5,6), (7,8)] := [real0, imag0, real1, imag1, real2, imag2, real3, imag3]
+            - "reference": [(1,5), (2,6), (3,7), (4,8)] := [real0, real1, real2, real3, imag0, imag1, imag2, imag3]
 
         :raise ValueError: when ``encoding_dim`` is not even.
+        :raise ValueError: when ``impl`` is not a valid implementation selection.
         """
         super().__init__(encoding_dim)
 
@@ -561,6 +569,11 @@ def __init__(
                 f"`encoding_dim` must be even, but is {encoding_dim} instead."
             )
 
+        if impl not in ["reference", "llama"]:
+            raise ValueError(
+                f"`impl` must be one of [\"reference\", \"llama\"], but is {impl} instead."
+            )
+
         cos_freqs = torch.empty(
             (max_seq_len + 1, encoding_dim), device=device, dtype=torch.float32
         )
@@ -576,6 +589,8 @@ def __init__(
 
         self.theta = theta
 
+        self.impl = impl
+
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -609,11 +624,18 @@ def reset_non_persistent_buffers(self) -> None:
         cos = torch.cos(table)
         sin = torch.sin(table)
 
-        self.cos_freqs[1:, : encoding_dim // 2] = cos
-        self.cos_freqs[1:, encoding_dim // 2 :] = cos
-
-        self.sin_freqs[1:, : encoding_dim // 2] = sin
-        self.sin_freqs[1:, encoding_dim // 2 :] = sin
+        if self.impl == "reference":
+            # Split-half layout: [real0, real1, real2, real3, imag0, imag1, imag2, imag3]
+            self.cos_freqs[1:, : encoding_dim // 2] = cos
+            self.cos_freqs[1:, encoding_dim // 2 :] = cos
+
+            self.sin_freqs[1:, : encoding_dim // 2] = sin
+            self.sin_freqs[1:, encoding_dim // 2 :] = sin
+        else:  # llama
+            # Consecutive layout: [real0, imag0, real1, imag1, real2, imag2, real3, imag3]
+            # (S, E / 2) -> (S, E); both elements of a pair share one frequency.
+            self.cos_freqs[1:] = torch.repeat_interleave(cos, 2, dim=-1)
+            self.sin_freqs[1:] = torch.repeat_interleave(sin, 2, dim=-1)
 
     @override
     def forward(
@@ -666,25 +688,38 @@ def forward(
 
         fp32_seqs = seqs.float()
 
-        fp32_rotated_seqs = self._rotate_half_way(fp32_seqs)
+        if self.impl == "reference":
+            fp32_rotated_seqs = self._rotate_half_way(fp32_seqs)
+        else:  # llama
+            fp32_rotated_seqs = self._reorder_to_consecutive_pairs(fp32_seqs)
 
         fp32_seqs = (fp32_seqs * cos_freqs) + (fp32_rotated_seqs * sin_freqs)
 
         return fp32_seqs.type_as(seqs)
 
     def _rotate_half_way(self, seqs: Tensor) -> Tensor:
+        """Rotation for split-half layout: [1,2,3,4,5,6,7,8] -> [-5,-6,-7,-8,1,2,3,4]"""
         half1 = seqs[..., : self.encoding_dim // 2]
         half2 = seqs[..., self.encoding_dim // 2 :]
 
         return torch.cat((-half2, half1), dim=-1)
 
+    def _reorder_to_consecutive_pairs(self, seqs: Tensor) -> Tensor:
+        """Rotation for consecutive layout: [1,2,3,4,5,6,7,8] -> [-2,1,-4,3,-6,5,-8,7]"""
+        even_parts = seqs[..., 0::2]
+        odd_parts = seqs[..., 1::2]
+
+        # 2 x (*, E / 2) -> (*, E / 2, 2) -> (*, E)
+        return torch.stack((-odd_parts, even_parts), dim=-1).flatten(-2)
+
     @override
     def extra_repr(self) -> str:
         """:meta private:"""
         return (
             f"encoding_dim={self.encoding_dim}, "
             f"max_seq_len={self.max_seq_len}, "
-            f"theta={self.theta}"
+            f"theta={self.theta}, "
+            f"impl={self.impl}"
         )
 
 