106 changes: 106 additions & 0 deletions llmfoundry/models/llama/attention.py
@@ -0,0 +1,106 @@
from flash_attn import flash_attn_func
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from .liger_rope import LigerRopeFunction
from .config import LlamaConfig

class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_attention_heads`: {self.num_heads})."
)

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Precompute the RoPE cos/sin tables once and cache them as non-persistent buffers.
        cos_cached, sin_cached = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.q_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        # Inverse frequencies for each even dimension; the outer product with the position index
        # gives the rotation angle for every (position, dimension-pair).
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # [max_position_embeddings, head_dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)       # [max_position_embeddings, head_dim]
        cos = emb.cos().to(dtype)
        sin = emb.sin().to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)     # each [1, max_position_embeddings, head_dim]

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
        # hidden_states: [batch_size, seq_len, hidden_size]
bsz, seq_len, _ = hidden_states.size()

if position_ids is None:
position_ids = torch.arange(seq_len, device=hidden_states.device)
position_ids = repeat(position_ids, 'l -> b l', b=bsz)

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = rearrange(query_states, "b s (h d) -> b s h d", h=self.num_heads, d=self.head_dim)
key_states = rearrange(key_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)
value_states = rearrange(value_states, "b s (h d) -> b s h d", h=self.num_key_value_heads, d=self.head_dim)

        # Gather the position-specific RoPE values from the cached tables
cos = self.cos_cached[:, position_ids] # [1, bsz, seq_len, dim]
sin = self.sin_cached[:, position_ids] # [1, bsz, seq_len, dim]

query_states, key_states = LigerRopeFunction.apply(
query_states,
key_states,
cos.squeeze(0),
sin.squeeze(0),
position_ids
)

        # flash_attn_func expects [batch, seq_len, num_heads, head_dim] tensors and handles GQA
        # (fewer KV heads than query heads) internally. The attention_mask itself is not passed
        # to the kernel; supplying one only disables causal masking here.
        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout_p=0.0,
            causal=attention_mask is None,
        )

attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
return self.o_proj(attn_output)
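
A minimal smoke-test sketch (not part of this diff) for the cached RoPE tables. It assumes the flash-attn package and the local liger_rope module import cleanly, since attention.py imports both at module load; the full forward pass additionally needs a GPU.

# Hypothetical check of the RoPE cache shapes, runnable on CPU.
import torch
from llmfoundry.models.llama.config import LlamaConfig
from llmfoundry.models.llama.attention import LlamaAttention

config = LlamaConfig()
attn = LlamaAttention(config)
assert attn.cos_cached.shape == (1, config.max_position_embeddings, attn.head_dim)
assert attn.sin_cached.shape == (1, config.max_position_embeddings, attn.head_dim)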
15 changes: 15 additions & 0 deletions llmfoundry/models/llama/config.py
@@ -0,0 +1,15 @@
from dataclasses import dataclass

@dataclass
class LlamaConfig:
hidden_size: int = 576
num_attention_heads: int = 9
num_key_value_heads: int = 3
num_hidden_layers: int = 30
intermediate_size: int = 1536
hidden_act: str = "silu"
rms_norm_eps: float = 1e-5
vocab_size: int = 49152
max_position_embeddings: int = 8192
    rope_theta: float = 100000.0
tie_word_embeddings: bool = False
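
For reference, the attention geometry implied by these defaults works out as follows (a quick sketch, not part of the diff):

from llmfoundry.models.llama.config import LlamaConfig

config = LlamaConfig()
head_dim = config.hidden_size // config.num_attention_heads             # 576 // 9 = 64
kv_groups = config.num_attention_heads // config.num_key_value_heads    # 9 // 3 = 3 query heads per KV head
assert head_dim * config.num_attention_heads == config.hidden_size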
40 changes: 40 additions & 0 deletions llmfoundry/models/llama/decoder.py
@@ -0,0 +1,40 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .attention import LlamaAttention

class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.self_attn = LlamaAttention(config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states
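
A hypothetical round-trip sketch for a single decoder layer (not part of this diff). It assumes a CUDA device with flash-attn installed, since LlamaAttention calls flash_attn_func, which requires half-precision inputs on GPU.

import torch
from llmfoundry.models.llama.config import LlamaConfig
from llmfoundry.models.llama.decoder import LlamaDecoderLayer

config = LlamaConfig()
layer = LlamaDecoderLayer(config).cuda().to(torch.bfloat16)
x = torch.randn(2, 16, config.hidden_size, device="cuda", dtype=torch.bfloat16)
y = layer(x)
assert y.shape == x.shape  # the layer is shape-preserving: [batch, seq_len, hidden_size]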
85 changes: 85 additions & 0 deletions llmfoundry/models/llama/front_end.py
@@ -0,0 +1,85 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .config import LlamaConfig
from .model import LlamaModel

class LlamaForCausalLM(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: share a single matrix between the input token embedding and the output classifier (lm_head).
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight

self._init_weights()

def _init_weights(self):
"""Initialize weights for all layers."""
        # Initialize the token embeddings; 0.041666... is 1 / 24 = 1 / sqrt(576),
        # i.e. 1 / sqrt(hidden_size) for the default config.
        if hasattr(self.model, 'embed_tokens'):
            nn.init.normal_(self.model.embed_tokens.weight, mean=0.0, std=0.041666666666666664)

# Initialize linear layers
for module in self.modules():
if isinstance(module, nn.Linear):
# Xavier/Glorot initialization for weights
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
# Zero initialization for biases
nn.init.zeros_(module.bias)

def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)

        # Return the final hidden states together with the classifier (lm_head) weights instead of
        # materializing logits here; callers compute logits as hidden_states @ weight.T (see generate).
        return hidden_states, self.lm_head.weight

@torch.no_grad()
def generate(
self,
input_ids: torch.LongTensor,
max_new_tokens: int = 30,
temperature: float = 0.0,
) -> torch.LongTensor:
self.eval()
bsz, seq_len = input_ids.shape

position_ids = repeat(
torch.arange(seq_len, device=input_ids.device),
'l -> b l',
b=bsz
)

        # No KV cache: the full sequence is re-encoded on every decoding step.
        for _ in range(max_new_tokens):
            hidden_states, classifier_weights = self.forward(input_ids, position_ids=position_ids)

# Get logits by computing hidden_states @ classifier_weights.T
next_token_logits = hidden_states[:, -1] @ classifier_weights.T

if temperature == 0:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
else:
scaled_logits = next_token_logits / temperature
probs = torch.softmax(scaled_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

input_ids = torch.cat([input_ids, next_token], dim=1)
new_position_ids = position_ids[:, -1:] + 1
position_ids = torch.cat([position_ids, new_position_ids], dim=1)

return input_ids
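
A hypothetical end-to-end usage sketch (not part of this diff). It assumes a CUDA device, flash-attn, and half-precision weights; note that forward() returns hidden states plus the classifier weights rather than logits, and generate() computes the logits itself.

import torch
from llmfoundry.models.llama.config import LlamaConfig
from llmfoundry.models.llama.front_end import LlamaForCausalLM

config = LlamaConfig()
model = LlamaForCausalLM(config).cuda().to(torch.bfloat16)

prompt_ids = torch.randint(0, config.vocab_size, (1, 8), device="cuda")
out_ids = model.generate(prompt_ids, max_new_tokens=5, temperature=0.0)  # greedy decoding
assert out_ids.shape == (1, 8 + 5)

# Logits must be computed by the caller from the returned pair:
hidden, w = model(prompt_ids)
logits = hidden[:, -1] @ w.T  # [batch, vocab_size]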