diff --git a/llmfoundry/models/llama/__init__.py b/llmfoundry/models/llama/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llmfoundry/models/llama/attention.py b/llmfoundry/models/llama/attention.py new file mode 100644 index 0000000..30515e1 --- /dev/null +++ b/llmfoundry/models/llama/attention.py @@ -0,0 +1,106 @@ +from flash_attn import flash_attn_func +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from .liger_rope import LigerRopeFunction +from .config import LlamaConfig + +class LlamaAttention(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_attention_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.register_buffer( + "cos_cached", + self._compute_rope_embeddings( + self.max_position_embeddings, + self.head_dim, + self.rope_theta, + dtype=torch.float32, + device=self.q_proj.weight.device, + )[0], + persistent=False, + ) + self.register_buffer( + "sin_cached", + self._compute_rope_embeddings( + self.max_position_embeddings, + self.head_dim, + self.rope_theta, + dtype=torch.float32, + device=self.q_proj.weight.device, + )[1], + persistent=False, + ) + + def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None): + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)) + t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype) + sin = emb.sin().to(dtype) + return cos.unsqueeze(0), sin.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # In B S (H D) + bsz, seq_len, _ = hidden_states.size() + + if position_ids is None: + position_ids = torch.arange(seq_len, device=hidden_states.device) + position_ids = repeat(position_ids, 'l -> b l', b=bsz) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = rearrange(query_states, "b s (h d) -> b s h d", h=self.num_heads, d=self.head_dim) + key_states = rearrange(key_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim) + value_states = rearrange(value_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim) + + # Slice off position specific rope freqs from the cached freqs + cos = 
self.cos_cached[:, position_ids] # [1, bsz, seq_len, dim] + sin = self.sin_cached[:, position_ids] # [1, bsz, seq_len, dim] + + query_states, key_states = LigerRopeFunction.apply( + query_states, + key_states, + cos.squeeze(0), + sin.squeeze(0), + position_ids + ) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=0.0, + causal=attention_mask is None + ) + + attn_output = rearrange(attn_output, "b s h d -> b s (h d)") + return self.o_proj(attn_output) \ No newline at end of file diff --git a/llmfoundry/models/llama/config.py b/llmfoundry/models/llama/config.py new file mode 100644 index 0000000..d5e43b3 --- /dev/null +++ b/llmfoundry/models/llama/config.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + +@dataclass +class LlamaConfig: + hidden_size: int = 576 + num_attention_heads: int = 9 + num_key_value_heads: int = 3 + num_hidden_layers: int = 30 + intermediate_size: int = 1536 + hidden_act: str = "silu" + rms_norm_eps: float = 1e-5 + vocab_size: int = 49152 + max_position_embeddings: int = 8192 + rope_theta: int = 100000 + tie_word_embeddings: bool = False \ No newline at end of file diff --git a/llmfoundry/models/llama/decoder.py b/llmfoundry/models/llama/decoder.py new file mode 100644 index 0000000..aa4ec6e --- /dev/null +++ b/llmfoundry/models/llama/decoder.py @@ -0,0 +1,40 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .mlp import LlamaMLP +from .config import LlamaConfig +from .rms_norm import LlamaRMSNorm +from .attention import LlamaAttention + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.self_attn = LlamaAttention(config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states \ No newline at end of file diff --git a/llmfoundry/models/llama/front_end.py b/llmfoundry/models/llama/front_end.py new file mode 100644 index 0000000..9198d68 --- /dev/null +++ b/llmfoundry/models/llama/front_end.py @@ -0,0 +1,85 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from .config import LlamaConfig +from .model import LlamaModel + +class LlamaForCausalLM(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.model = LlamaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Weight tying uses the head weights as the classifier for the token embeddings for both in and out. 
+ if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self._init_weights() + + def _init_weights(self): + """Initialize weights for all layers.""" + # Initialize embeddings + if hasattr(self.model, 'embed_tokens'): + nn.init.normal_(self.model.embed_tokens.weight, mean=0.0, std=0.041666666666666664) + + # Initialize linear layers + for module in self.modules(): + if isinstance(module, nn.Linear): + # Xavier/Glorot initialization for weights + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + # Zero initialization for biases + nn.init.zeros_(module.bias) + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + return hidden_states, self.lm_head.weight + + @torch.no_grad() + def generate( + self, + input_ids: torch.LongTensor, + max_new_tokens: int = 30, + temperature: float = 0.0, + ) -> torch.LongTensor: + self.eval() + bsz, seq_len = input_ids.shape + + position_ids = repeat( + torch.arange(seq_len, device=input_ids.device), + 'l -> b l', + b=bsz + ) + + for _ in range(max_new_tokens): + hidden_states, classifier_weights = self.forward(input_ids, position_ids=position_ids) + + # Get logits by computing hidden_states @ classifier_weights.T + next_token_logits = hidden_states[:, -1] @ classifier_weights.T + + if temperature == 0: + next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) + else: + scaled_logits = next_token_logits / temperature + probs = torch.softmax(scaled_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + + input_ids = torch.cat([input_ids, next_token], dim=1) + new_position_ids = position_ids[:, -1:] + 1 + position_ids = torch.cat([position_ids, new_position_ids], dim=1) + + return input_ids \ No newline at end of file diff --git a/llmfoundry/models/llama/liger_rope.py b/llmfoundry/models/llama/liger_rope.py new file mode 100644 index 0000000..f03441e --- /dev/null +++ b/llmfoundry/models/llama/liger_rope.py @@ -0,0 +1,257 @@ +import torch +import triton +import triton.language as tl + +# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rope.py +# BSD 2-CLAUSE LICENSE +# Copyright 2024 LinkedIn Corporation +# All Rights Reserved. +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@triton.jit +def _triton_rope( + q_ptr, + q_row_stride, + k_ptr, + k_row_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + sl, + bs: tl.constexpr, + cos_bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + # q size: (bsz, seq_len, num_q_heads, head_dim) + # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) + # k size: (bsz, seq_len, num_kv_heads, head_dim) + # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) + + # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + # stride: (seq_len * head_dim, head_dim, 1) + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * q_row_stride + k_ptr = k_ptr + pid * k_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. + batch_idx = pid // sl + cos_row_idx = pid % sl + cos = cos + tl.where( + cos_bs == 1, + cos_row_idx * cos_row_stride, + batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride, + ) + sin = sin + tl.where( + cos_bs == 1, + cos_row_idx * sin_row_stride, + batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride, + ) + + cos_offsets = tl.arange(0, pad_hd // 2) + cos_mask = cos_offsets < hd // 2 + cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) + sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. 
for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def rope_forward(q, k, cos, sin): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + cos_batch_size = cos.shape[0] + + _triton_rope[(n_row,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q, k, cos, sin + + +def rope_backward(dq, dk, cos, sin): + batch_size, seq_len, n_q_head, head_dim = dq.shape + cos_batch_size = cos.shape[0] + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_rope[(n_row,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq, dk + + +class LigerRopeFunction(torch.autograd.Function): + """ + Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that + this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different + than the original RoPE paper. + + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184 + + For more details about the rotation matrix used here, please refer to: + https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2 + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + q, k, cos, sin = rope_forward(q, k, cos, sin) + ctx.save_for_backward(cos, sin) + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + + cos, sin = ctx.saved_tensors + dq, dk = rope_backward(dq, dk, cos, sin) + return dq, dk, None, None, None, None \ No newline at end of file diff --git a/llmfoundry/models/llama/mlp.py b/llmfoundry/models/llama/mlp.py new file mode 100644 index 0000000..e2147f9 --- /dev/null +++ b/llmfoundry/models/llama/mlp.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import LlamaConfig + +class LlamaMLP(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.act_fn = nn.SiLU() + 
+ def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) \ No newline at end of file diff --git a/llmfoundry/models/llama/model.py b/llmfoundry/models/llama/model.py new file mode 100644 index 0000000..05b1c7f --- /dev/null +++ b/llmfoundry/models/llama/model.py @@ -0,0 +1,35 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .mlp import LlamaMLP +from .config import LlamaConfig +from .rms_norm import LlamaRMSNorm +from .decoder import LlamaDecoderLayer + +class LlamaModel(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=None) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = self.norm(hidden_states) + return hidden_states \ No newline at end of file diff --git a/llmfoundry/models/llama/rms_norm.py b/llmfoundry/models/llama/rms_norm.py new file mode 100644 index 0000000..05c06bc --- /dev/null +++ b/llmfoundry/models/llama/rms_norm.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) diff --git a/llmfoundry/models/llama/rope.py b/llmfoundry/models/llama/rope.py new file mode 100644 index 0000000..2d9cb22 --- /dev/null +++ b/llmfoundry/models/llama/rope.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def rotate_half(x): + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin): + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=8192, base=10000): + super().__init__() + self.dim = dim + self.base = base + self.max_position_embeddings = max_position_embeddings + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, position_ids: torch.LongTensor): + # position_ids: [batch_size, seq_len] + inv_freq = self.inv_freq.to(device=position_ids.device) + inv_freq_expanded = inv_freq[None, None, :] # [1, 1, dim//2] + position_ids_expanded = position_ids[:, :, None].float() # [batch_size, seq_len, 1] + freqs = torch.matmul(position_ids_expanded, inv_freq_expanded) # [batch_size, seq_len, dim//2] + freqs = torch.cat([freqs, freqs], dim=-1) # [batch_size, seq_len, dim] + cos = torch.cos(freqs) 
+ sin = torch.sin(freqs) + cos = cos.unsqueeze(1) # [batch_size, 1, seq_len, dim] + sin = sin.unsqueeze(1) # [batch_size, 1, seq_len, dim] + return cos, sin
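A few illustrative sketches follow; none of them are part of the diff itself. First, LlamaAttention relies on flash_attn_func to broadcast its 3 key/value heads across the 9 query heads (grouped-query attention). The helper below is a hypothetical eager equivalent, useful for CPU debugging, assuming the usual convention (shared by flash-attn and HF's repeat_kv, to the best of my knowledge) that each KV head serves a contiguous block of num_key_value_groups query heads. eager_gqa_attention is my name, not an API shipped by the diff.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

def eager_gqa_attention(q, k, v, num_key_value_groups: int, causal: bool = True):
    """Eager stand-in for flash_attn_func with GQA: q is (b, s, h_q, d), k/v are (b, s, h_kv, d)."""
    # Expand each KV head across its group of query heads, matching repeat_kv's layout.
    k = repeat(k, "b s h d -> b s (h g) d", g=num_key_value_groups)
    v = repeat(v, "b s h d -> b s (h g) d", g=num_key_value_groups)
    # SDPA expects (b, h, s, d); flash_attn_func works directly in (b, s, h, d).
    q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return rearrange(out, "b h s d -> b s h d")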
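The rotation applied by _triton_rope (new_x1 = x1*cos - x2*sin, new_x2 = x2*cos + x1*sin) is the rotate-half form that rope.py implements eagerly. A hypothetical cross-check of the two paths, assuming a CUDA device with Triton available, since the kernel does not run on CPU:

import torch
from llmfoundry.models.llama.liger_rope import LigerRopeFunction
from llmfoundry.models.llama.rope import apply_rotary_pos_emb

if torch.cuda.is_available():
    b, s, n_q, n_kv, d = 2, 16, 9, 3, 64
    q = torch.randn(b, s, n_q, d, device="cuda")
    k = torch.randn(b, s, n_kv, d, device="cuda")

    # Build cos/sin the same way attention.py caches them: duplicated half-frequencies.
    inv_freq = 1.0 / (100000 ** (torch.arange(0, d, 2, device="cuda").float() / d))
    freqs = torch.einsum("i,j->ij", torch.arange(s, device="cuda").float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[None], emb.sin()[None]          # (1, s, d)

    # The eager path from rope.py works in (b, h, s, d) with cos/sin broadcast over heads.
    q_ref, k_ref = apply_rotary_pos_emb(
        q.transpose(1, 2), k.transpose(1, 2), cos.unsqueeze(1), sin.unsqueeze(1)
    )
    # The Triton kernel works in (b, s, h, d) and writes in place, hence the clones.
    q_tri, k_tri = LigerRopeFunction.apply(q.clone(), k.clone(), cos, sin)
    assert torch.allclose(q_tri, q_ref.transpose(1, 2), atol=1e-5)
    assert torch.allclose(k_tri, k_ref.transpose(1, 2), atol=1e-5)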
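LlamaRMSNorm normalizes by the root-mean-square in float32 before rescaling. A quick CPU sanity check against the closed-form y = w * x / sqrt(mean(x^2) + eps), purely illustrative:

import torch
from llmfoundry.models.llama.rms_norm import LlamaRMSNorm

x = torch.randn(2, 5, 576)
norm = LlamaRMSNorm(576, eps=1e-5)
reference = norm.weight * x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
assert torch.allclose(norm(x), reference, atol=1e-6)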
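LlamaForCausalLM.forward deliberately returns the final hidden states together with the lm_head weight instead of materializing logits (generate does the projection itself). One plausible way to consume that contract in a training step, sketched here rather than taken from the diff, is to project with F.linear and apply the standard one-token shift; causal_lm_loss is a hypothetical helper name.

import torch
import torch.nn.functional as F

def causal_lm_loss(model, input_ids: torch.LongTensor) -> torch.Tensor:
    # Assumes the model sits on a CUDA device in bf16/fp16, since its attention uses flash-attn.
    hidden_states, classifier_weights = model(input_ids)
    logits = F.linear(hidden_states, classifier_weights)   # (b, s, vocab), same as h @ W.T
    shift_logits = logits[:, :-1, :].float()
    shift_labels = input_ids[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )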
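Finally, a hypothetical end-to-end smoke test with the defaults from config.py. It assumes a CUDA device with flash-attn and Triton installed and casts the randomly initialized model to bfloat16, which flash_attn_func requires; the generated tokens are of course meaningless without trained weights.

import torch
from llmfoundry.models.llama.config import LlamaConfig
from llmfoundry.models.llama.front_end import LlamaForCausalLM

config = LlamaConfig()
model = LlamaForCausalLM(config).to(device="cuda", dtype=torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
tokens = model.generate(input_ids, max_new_tokens=8, temperature=0.0)
print(tokens.shape)   # torch.Size([2, 24]) -- 16 prompt tokens plus 8 generated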