70 changes: 56 additions & 14 deletions llmfoundry/models/llama/custom_model.py
@@ -41,6 +41,10 @@
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.ops.triton.rotary import apply_rotary

from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

SMOLLM2_CONFIG_135M = LlamaConfig(
attention_bias = False,
attention_dropout = 0.0,
@@ -69,6 +73,9 @@
use_cache = True,
vocab_size = 49152,
_attn_implementation = "sdpa",
_use_liger_rms_norm = False,
_use_liger_fused_crossentropy = False,
_use_liger_mlp = False,
)

# Modernbert unpadding and repadding
@@ -468,9 +475,10 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
norm_cls = LigerRMSNorm if config._use_liger_rms_norm else LlamaRMSNorm
self.mlp = LigerSwiGLUMLP(config) if config._use_liger_mlp else LlamaMLP(config)
self.input_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
@@ -572,10 +580,13 @@ def __init__(self, config: LlamaConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
use_liger_rms_norm = getattr(config, '_use_liger_rms_norm', False)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
norm_cls = LigerRMSNorm if use_liger_rms_norm else LlamaRMSNorm
self.norm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self._use_liger_fused_crossentropy = getattr(config, '_use_liger_fused_crossentropy', False)
self.can_generate = True
self.tie_weights()

@@ -638,10 +649,11 @@ def forward(
hidden_states = _pad_modernbert_output(
inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len,
)
return self.lm_head(hidden_states)
if self._use_liger_fused_crossentropy: return hidden_states
else: return self.lm_head(hidden_states)

@classmethod
def from_pretrained(cls, model_type: str, device_map: str = "auto", torch_dtype: torch.dtype = torch.bfloat16):
def from_pretrained(cls, model_type: str, device_map: str = "auto", torch_dtype: torch.dtype = torch.bfloat16, use_liger_rms_norm: bool = False, use_liger_fused_crossentropy: bool = False, use_liger_mlp: bool = False):
if model_type == "smollm2-135m":
checkpoint = "HuggingFaceTB/SmolLM2-135M"
config = SMOLLM2_CONFIG_135M
@@ -650,6 +662,9 @@ def from_pretrained(cls, model_type: str, device_map: str = "auto", torch_dtype:
raise NotImplementedError("SmolLM2-1.7B config not yet implemented")
else:
raise ValueError(f"Model type {model_type} not supported")
config._use_liger_rms_norm = use_liger_rms_norm
config._use_liger_fused_crossentropy = use_liger_fused_crossentropy
config._use_liger_mlp = use_liger_mlp
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_hf = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=device_map, torch_dtype=torch_dtype).to(device)
sd_hf = model_hf.state_dict()
@@ -694,35 +709,49 @@ def get_decoder(self): return self
class CustomLlamaModel(BaseHuggingFaceModel):
"""Custom Llama model wrapper for LLM Foundry compatibility."""

_use_liger_rms_norm: bool = False
_use_liger_fused_crossentropy: bool = False
_use_liger_mlp: bool = False

def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
model_type: str = "smollm2-135m",
pretrained: bool = True,
use_liger_rms_norm: bool = False,
use_liger_fused_crossentropy: bool = False,
use_liger_mlp: bool = False,
peft_config: Optional[dict[str, Any]] = None,
pretrained_model_name_or_path: str = "HuggingFaceTB/SmolLM2-135M",
**kwargs: Any,
):
CustomLlamaModel._use_liger_rms_norm = use_liger_rms_norm
CustomLlamaModel._use_liger_fused_crossentropy = use_liger_fused_crossentropy
CustomLlamaModel._use_liger_mlp = use_liger_mlp
self._use_liger_fused_crossentropy = use_liger_fused_crossentropy
super().__init__(
pretrained_model_name_or_path=pretrained_model_name_or_path,
tokenizer=tokenizer,
pretrained=pretrained,
peft_config=peft_config,
shift_labels=True,
**kwargs,
)
)
if use_liger_fused_crossentropy: self.liger_loss_fn = LigerFusedLinearCrossEntropyLoss(ignore_index=CROSS_ENTROPY_IGNORE_INDEX, reduction='mean')

def forward(self, batch: dict[str, Any]) -> torch.Tensor:
return self.model(input_ids=batch['input_ids'])

def loss(self, outputs: torch.Tensor, batch: dict[str, Any]) -> torch.Tensor:
targets = torch.roll(batch['labels'], shifts=-1, dims=1)
targets[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
return F.cross_entropy(
outputs.flatten(0, -2),
targets.flatten(),
ignore_index=CROSS_ENTROPY_IGNORE_INDEX,
)
targets_flat = targets.flatten()
outputs_flat = outputs.flatten(0, -2)

if self._use_liger_fused_crossentropy:
return self.liger_loss_fn(self.model.lm_head.weight, outputs_flat, targets_flat)
else:
return F.cross_entropy(outputs_flat, targets_flat, ignore_index=CROSS_ENTROPY_IGNORE_INDEX)

def generate(
self,
@@ -772,10 +801,23 @@ def build_inner_model(
**kwargs: Any,
) -> Union[PreTrainedModel, 'PeftModel']:
"""Build your custom model instead of using AutoModelForCausalLM."""
use_liger_rms_norm = cls._use_liger_rms_norm
use_liger_fused_crossentropy = cls._use_liger_fused_crossentropy
use_liger_mlp = cls._use_liger_mlp
if pretrained:
model = LlamaModel.from_pretrained("smollm2-135m")
model = LlamaModel.from_pretrained("smollm2-135m", use_liger_rms_norm=use_liger_rms_norm, use_liger_fused_crossentropy=use_liger_fused_crossentropy, use_liger_mlp=use_liger_mlp)
model.config._use_liger_fused_crossentropy = use_liger_fused_crossentropy
model.config._use_liger_mlp = use_liger_mlp
else:
model = LlamaModel(SMOLLM2_CONFIG_135M)
from copy import deepcopy
config = deepcopy(SMOLLM2_CONFIG_135M)
config._use_liger_rms_norm = use_liger_rms_norm
config._use_liger_fused_crossentropy = use_liger_fused_crossentropy
config._use_liger_mlp = use_liger_mlp
if config_overrides:
for key, value in config_overrides.items():
setattr(config, key, value)
model = LlamaModel(config)

if pretrained_lora_id_or_path is not None:
from composer.models.huggingface import peft_installed
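Side note on the fused-loss path (not part of the diff): when use_liger_fused_crossentropy is enabled, forward() returns the unprojected hidden states and loss() hands lm_head.weight plus those hidden states to LigerFusedLinearCrossEntropyLoss, which fuses the vocabulary projection with the cross-entropy reduction so the full (tokens x vocab) logits tensor is never materialized. A minimal standalone sketch of that equivalence, assuming a CUDA device (the Liger Triton kernels require one), an ignore index of -100 for CROSS_ENTROPY_IGNORE_INDEX, and purely illustrative tensor shapes:

import torch
import torch.nn.functional as F
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

IGNORE_INDEX = -100                 # assumed value of CROSS_ENTROPY_IGNORE_INDEX
num_tokens, hidden_size, vocab_size = 16, 576, 49152   # illustrative SmolLM2-135M-like dims

hidden = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)   # flattened hidden states
weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16)   # stand-in for lm_head.weight
targets = torch.randint(0, vocab_size, (num_tokens,), device="cuda")                 # flattened shifted labels

fused = LigerFusedLinearCrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="mean")
loss_fused = fused(weight, hidden, targets)   # projection + cross-entropy fused, logits never materialized
loss_ref = F.cross_entropy(hidden.float() @ weight.float().T, targets, ignore_index=IGNORE_INDEX)
print(loss_fused.item(), loss_ref.item())     # expected to agree up to low-precision rounding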
1 change: 1 addition & 0 deletions pyproject.toml
@@ -556,6 +556,7 @@ dependencies = [
"aim>=3.26.0,<4",
"zstd>=1.5.6.1,!=1.5.6.2",
"math_verify>=0.6.0",
"liger-kernel",
"numba>=0.62.1",
]

23 changes: 13 additions & 10 deletions scripts/train/yamls/pretrain/custom_smollm2-135m.yaml
@@ -3,10 +3,10 @@ variables:
data_remote: hf://datasets/LocalResearchGroup/split-avelina-python-edu/tokenized/avelinapythonedu/full/
tokenizer_name: HuggingFaceTB/SmolLM2-135M
global_seed: 42
max_seq_len: 1024
max_seq_len: 8192
run_name: custom_smollm2_135m_training
device_microbatch_size: 4
global_train_batch_size: 8
device_microbatch_size: 2
global_train_batch_size: 4

max_seq_len: ${variables.max_seq_len}
run_name: ${variables.run_name}
@@ -16,6 +16,9 @@ model:
model_type: smollm2-135m
pretrained: false
pretrained_model_name_or_path: HuggingFaceTB/SmolLM2-135M
use_liger_rms_norm: true
use_liger_fused_crossentropy: true
use_liger_mlp: true

tokenizer:
name: ${variables.tokenizer_name}
@@ -24,9 +27,9 @@ tokenizer:

loggers:
wandb:
project: 'sequence-packing'
project: 'liger-kernel'
entity: 'local-research-group'
name: 'avelina-sequence-packing-2000ba-max-seq-len-1024_with-eval'
name: 'liger-all-compile(dynamic_f)_bs-2'

# Data loaders
train_loader:
@@ -119,11 +122,11 @@ fsdp_config:

# Torch compile configuration
compile_config:
dynamic: true # Enable dynamic shapes
options:
epilogue_fusion: true # Fuses pointwise ops into templates (requires max-autotune)
max_autotune: true # Enables automatic tuning of fusion patterns
coordinate_descent_tuning: true # Enables coordinate descent tuning for optimization
dynamic: false # Disable dynamic shapes (required for Triton kernels with liger_kernel)
# options:
# epilogue_fusion: true # Fuses pointwise ops into templates (requires max-autotune)
# max_autotune: true # Enables automatic tuning of fusion patterns
# coordinate_descent_tuning: true # Enables coordinate descent tuning for optimization

# Logging
progress_bar: true
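The model: block above maps onto the new constructor arguments of CustomLlamaModel. A rough usage sketch of how those flags are expected to reach the wrapper (import path inferred from the file path in the diff; building the wrapper outside the LLM Foundry trainer may need additional arguments, so treat this as illustrative rather than a verified end-to-end run):

from transformers import AutoTokenizer
from llmfoundry.models.llama.custom_model import CustomLlamaModel  # module path inferred from the diff

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model = CustomLlamaModel(
    tokenizer=tokenizer,
    model_type="smollm2-135m",
    pretrained=False,                   # matches pretrained: false in the YAML
    pretrained_model_name_or_path="HuggingFaceTB/SmolLM2-135M",
    use_liger_rms_norm=True,            # LlamaRMSNorm -> LigerRMSNorm
    use_liger_fused_crossentropy=True,  # forward() returns hidden states; loss() uses the fused kernel
    use_liger_mlp=True,                 # LlamaMLP -> LigerSwiGLUMLP
)

The compile_config change above pairs with these flags: per the in-file comment, dynamic shapes are disabled when the Liger Triton kernels are in use.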