diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index e23133f..71ff819 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -39,6 +39,8 @@ ) from llmfoundry.callbacks.run_timeout_callback import RunTimeoutCallback from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector +from llmfoundry.callbacks.text_generation_callback import TextGenerationCallback +from llmfoundry.callbacks.batch_inspection_callback import BatchInspectionCallback from llmfoundry.registry import callbacks, callbacks_with_config callbacks.register('system_metrics_monitor', func=SystemMetricsMonitor) @@ -65,6 +67,8 @@ callbacks.register('nan_monitor', func=NaNMonitor) callbacks.register('kill_loss_spike', func=KillLossSpike) callbacks.register('load_checkpoint', func=LoadCheckpoint) +callbacks.register('text_generation', func=TextGenerationCallback) +callbacks.register('batch_inspection', func=BatchInspectionCallback) callbacks_with_config.register('async_eval', func=AsyncEval) callbacks_with_config.register('curriculum_learning', func=CurriculumLearning) @@ -83,4 +87,6 @@ 'CurriculumLearning', 'LossPerpVsContextLengthLogger', 'KillLossSpike', + 'TextGenerationCallback', + 'BatchInspectionCallback', ] diff --git a/llmfoundry/callbacks/batch_inspection_callback.py b/llmfoundry/callbacks/batch_inspection_callback.py new file mode 100644 index 0000000..801a79c --- /dev/null +++ b/llmfoundry/callbacks/batch_inspection_callback.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +"""Callback for inspecting batch data passed to the model's forward method.""" + +import logging +import torch +from typing import Any, Dict +from composer.core import Callback, State +from composer.loggers import Logger + +logger = logging.getLogger(__name__) + +class BatchInspectionCallback(Callback): + """Callback that logs detailed information about batches passed to the model.""" + + def __init__( + self, + log_frequency: int = 10, + sample_size: int = 3, + log_to_console: bool = True, + log_to_wandb: bool = True, + ): + """Initialize the callback. 
+ + Args: + log_frequency: Log batch info every N batches + sample_size: Number of sample values to show from each tensor + log_to_console: Whether to print to console + log_to_wandb: Whether to log to wandb + """ + self.log_frequency = log_frequency + self.sample_size = sample_size + self.log_to_console = log_to_console + self.log_to_wandb = log_to_wandb + self.batch_count = 0 + self.original_forward = None + + def _inspect_tensor(self, tensor: torch.Tensor, name: str) -> dict[str, Any]: + """Extract detailed information from a tensor.""" + info = { + 'shape': list(tensor.shape), + 'dtype': str(tensor.dtype), + 'device': str(tensor.device), + 'requires_grad': tensor.requires_grad, + } + + # Get sample values + if tensor.numel() > 0: + flattened = tensor.flatten() + + # For input_ids and labels, show first 10 and last 10 + if name in ['input_ids', 'labels']: + first_10 = flattened[:10].tolist() if flattened.numel() >= 10 else flattened.tolist() + last_10 = flattened[-10:].tolist() if flattened.numel() >= 10 else [] + info['first_10'] = first_10 + info['last_10'] = last_10 + else: + # For other tensors, use original sampling method + sample_indices = torch.linspace(0, flattened.numel()-1, min(self.sample_size, flattened.numel()), dtype=torch.long) + samples = flattened[sample_indices].tolist() + info['samples'] = samples + + # Basic statistics for numeric tensors + if tensor.dtype in [torch.float16, torch.float32, torch.float64, torch.int32, torch.int64]: + info['min'] = tensor.min().item() + info['max'] = tensor.max().item() + info['mean'] = tensor.float().mean().item() + + return info + + def _inspect_batch(self, batch: Any, prefix: str = "") -> dict[str, Any]: + """Recursively inspect batch structure.""" + batch_info = {} + + if isinstance(batch, torch.Tensor): + return self._inspect_tensor(batch, prefix) + elif isinstance(batch, dict): + batch_info['type'] = 'dict' + batch_info['keys'] = list(batch.keys()) + batch_info['contents'] = {} + for key, value in batch.items(): + batch_info['contents'][key] = self._inspect_batch(value, f"{prefix}.{key}" if prefix else key) + elif isinstance(batch, (list, tuple)): + batch_info['type'] = type(batch).__name__ + batch_info['length'] = len(batch) + batch_info['contents'] = {} + for i, item in enumerate(batch): + batch_info['contents'][f'item_{i}'] = self._inspect_batch(item, f"{prefix}[{i}]" if prefix else f"item_{i}") + else: + batch_info = { + 'type': type(batch).__name__, + 'value': str(batch)[:100], # Truncate long strings + } + + return batch_info + + def _log_batch_info(self, batch_info: dict[str, Any], state: State, logger: Logger): + """Log batch information to console and wandb.""" + if self.log_to_console: + print(f"\n{'='*80}") + print(f"BATCH INSPECTION - Step {state.timestamp.batch}") + print(f"{'='*80}") + self._print_batch_info(batch_info) + print(f"{'='*80}\n") + + if self.log_to_wandb: + # Flatten the batch info for wandb logging + flat_info = self._flatten_batch_info(batch_info) + wandb_metrics = {} + for key, value in flat_info.items(): + if isinstance(value, (int, float, str, bool)): + wandb_metrics[f"batch_inspection/{key}"] = value + + if wandb_metrics and hasattr(logger, 'log_metrics'): + logger.log_metrics(wandb_metrics) + + def _print_batch_info(self, info: Any, indent: int = 0): + """Recursively print batch information.""" + prefix = " " * indent + + if isinstance(info, dict): + if 'type' in info and info['type'] in ['dict', 'list', 'tuple']: + print(f"{prefix}Type: {info['type']}") + if 'keys' in info: + 
print(f"{prefix}Keys: {info['keys']}") + if 'length' in info: + print(f"{prefix}Length: {info['length']}") + if 'contents' in info: + for key, value in info['contents'].items(): + print(f"{prefix}{key}:") + self._print_batch_info(value, indent + 1) + else: + # Tensor info + for key, value in info.items(): + if key in ['samples', 'first_10', 'last_10']: + print(f"{prefix}{key}: {value}") + else: + print(f"{prefix}{key}: {value}") + else: + print(f"{prefix}{info}") + + def _flatten_batch_info(self, info: Any, prefix: str = "") -> dict[str, Any]: + """Flatten nested batch info for wandb logging.""" + flat = {} + + if isinstance(info, dict): + if 'shape' in info: # Tensor info + flat[f"{prefix}_shape"] = str(info['shape']) + flat[f"{prefix}_dtype"] = info['dtype'] + flat[f"{prefix}_device"] = info['device'] + if 'min' in info: + flat[f"{prefix}_min"] = info['min'] + if 'max' in info: + flat[f"{prefix}_max"] = info['max'] + if 'mean' in info: + flat[f"{prefix}_mean"] = info['mean'] + elif 'contents' in info: + for key, value in info['contents'].items(): + new_prefix = f"{prefix}_{key}" if prefix else key + flat.update(self._flatten_batch_info(value, new_prefix)) + + return flat + + def _create_forward_wrapper(self, original_forward): + """Create a wrapper around the model's forward method.""" + def forward_wrapper(*args, **kwargs): + # Inspect the batch (typically the first argument) + if args and self.batch_count % self.log_frequency == 0: + batch = args[0] if len(args) > 0 else kwargs + batch_info = self._inspect_batch(batch) + + # We'll store this info to log it in the next callback event + if hasattr(self, '_current_state') and hasattr(self, '_current_logger'): + self._log_batch_info(batch_info, self._current_state, self._current_logger) + + self.batch_count += 1 + return original_forward(*args, **kwargs) + + return forward_wrapper + + def batch_start(self, state: State, logger: Logger) -> None: + """Called at the start of each batch.""" + # Store state and logger for use in forward wrapper + self._current_state = state + self._current_logger = logger + + # Wrap the model's forward method if not already wrapped + if self.original_forward is None: + model = state.model + if hasattr(model, 'forward'): + self.original_forward = model.forward + model.forward = self._create_forward_wrapper(self.original_forward) + + def batch_end(self, state: State, logger: Logger) -> None: + """Called at the end of each batch.""" + # Clean up references + self._current_state = None + self._current_logger = None + + def fit_end(self, state: State, logger: Logger) -> None: + """Restore original forward method when training ends.""" + if self.original_forward is not None: + state.model.forward = self.original_forward + self.original_forward = None diff --git a/text_generation_callback.py b/llmfoundry/callbacks/text_generation_callback.py similarity index 58% rename from text_generation_callback.py rename to llmfoundry/callbacks/text_generation_callback.py index c5feb03..bd848a4 100644 --- a/text_generation_callback.py +++ b/llmfoundry/callbacks/text_generation_callback.py @@ -70,17 +70,50 @@ def _generate_and_log_text(self, state: State, logger: Logger, event_name: str): "prompt": prompt, "error": str(e), } - - if self.log_to_wandb and hasattr(logger, 'log_metrics'): + if self.log_to_wandb: + # Prepare table data for all successful generations + table_data = [] for key, value in generated_texts.items(): if "error" not in value: - logger.log_metrics({ - f"generation/{event_name}/{key}/prompt": str(value["prompt"]), 
- f"generation/{event_name}/{key}/text": str(value["generated"]), - }) - print(f"WandB Logged: {event_name} - {key}") - print(f" Prompt: {value['prompt']}") - print(f" Generated: {value['generated']}") + table_data.append([key, value["prompt"], value["generated"]]) + + if table_data: + # Log to W&B directly using proper wandb.Table + import wandb + from composer.loggers import WandBLogger + + # Find the WandBLogger destination + wandb_logger = None + for destination in logger.destinations: + if isinstance(destination, WandBLogger): + wandb_logger = destination + break + + if wandb_logger and hasattr(wandb_logger, '_run') and wandb_logger._run: + # Create proper W&B table + table = wandb.Table( # pyright: ignore[reportAttributeAccessIssue] + columns=["prompt_id", "prompt", "generated_text"], + data=table_data, + ) + # Log the table directly to W&B with step-specific name + wandb_logger._run.log({ + f"generation/{event_name}_step_{state.timestamp.batch.value}": table, + }, step=state.timestamp.batch.value) + print(f"WandB Logged: {event_name}") + print(f" Number of generations: {len(table_data)}") + print() + else: + # Fallback to Composer's log_table method + for destination in logger.destinations: + if hasattr(destination, 'log_table'): + destination.log_table( + columns=["prompt_id", "prompt", "generated_text"], + rows=table_data, + name=f"generation/{event_name}_step_{state.timestamp.batch.value}", + step=state.timestamp.batch.value, + ) + print(f"WandB Logged (fallback): {event_name}") + print(f" Number of generations: {len(table_data)}") print() except Exception as e: @@ -93,6 +126,3 @@ def fit_start(self, state: State, logger: Logger) -> None: def eval_start(self, state: State, logger: Logger) -> None: """Generate text before evaluation starts.""" self._generate_and_log_text(state, logger, "BEFORE_EVAL") - -from llmfoundry.registry import callbacks -callbacks.register('text_generation', func=TextGenerationCallback) \ No newline at end of file diff --git a/llmfoundry/models/llama/custom_model.py b/llmfoundry/models/llama/custom_model.py index 4b72df9..ed336ca 100644 --- a/llmfoundry/models/llama/custom_model.py +++ b/llmfoundry/models/llama/custom_model.py @@ -1,3 +1,6 @@ +# TODO: Clean up multiple LlamaEmbeddings in LlamaModel. +# TODO: Implement KV cache. + # coding=utf-8 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
# @@ -28,10 +31,16 @@ from torch import nn from transformers.models.llama.configuration_llama import LlamaConfig from transformers import PreTrainedTokenizerBase, PreTrainedModel +from transformers.utils import is_flash_attn_2_available from llmfoundry.data.finetuning.collator import CROSS_ENTROPY_IGNORE_INDEX from transformers import AutoModelForCausalLM +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary + SMOLLM2_CONFIG_135M = LlamaConfig( attention_bias = False, attention_dropout = 0.0, @@ -59,8 +68,195 @@ transformers_version = "4.55.0.dev0", use_cache = True, vocab_size = 49152, + _attn_implementation = "flash_attention_2", ) +# Modernbert unpadding and repadding +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, +) -> torch.Tensor: + """Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + + Returns: + padded_inputs: (batch, seqlen, ...) 
or (batch, seqlen) + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + return padded_inputs + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, nheads, headdim) + apply_rotary( + x, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return x + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + apply_rotary( + do, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + x, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """Arguments: + x: (total_nnz, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ # noqa: D205 + return ApplyRotaryEmbUnpad.apply(x, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """The rotary position embeddings applied directly to unpadded sequences.""" + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache will be recomputed during the forward pass. + """ # noqa: D205 + super().__init__(dim=dim, base=base, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Apply rotary embedding *inplace* to x. 
+ x: (total_nnz, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ # noqa: D205 + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=x.device, dtype=x.dtype) + + x = apply_rotary_unpadded( + x, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() @@ -128,6 +324,101 @@ def __init__(self, config: LlamaConfig): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +def flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + is_causal: bool, + scaling: float, + enable_gqa: bool, + input_shape: tuple[int, int], + hidden_shape: tuple[int, int], + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]], + cu_seqlens: torch.Tensor, + max_seqlen: int, + rotary_emb: Union[LlamaRotaryEmbedding, ModernBertUnpaddedRotaryEmbedding], + target_dtype: torch.dtype = torch.bfloat16, + **kwargs: Any, +) -> tuple[torch.Tensor]: + # (total_seqlen, nheads, headdim) + query_states = query_states.view(hidden_shape) + key_states = key_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) + + query_states = rotary_emb(query_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + key_states = rotary_emb(key_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + convert_dtype = query_states.dtype not in (torch.float16, torch.bfloat16) + + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) + orig_dtype = query_states.dtype + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + deterministic=False, + causal=is_causal, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + deterministic=False, + causal=is_causal, + ) + total_tokens = attn.shape[0] + hidden_size = attn.shape[1] * attn.shape[2] # num_heads * head_dim + return attn.view(total_tokens, hidden_size) + +def sdpa_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + is_causal: bool, + scaling: float, + enable_gqa: bool, + input_shape: tuple[int, int], + hidden_shape: tuple[int, int], + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]], + **kwargs: Any, +) -> torch.Tensor: + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attn_output = nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, is_causal=is_causal, + scale=scaling, enable_gqa=enable_gqa).transpose(1,2) + + attn_output = attn_output.reshape(*input_shape, -1) + return attn_output + +LLAMA_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "sdpa": sdpa_attention_forward, +} + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig, layer_idx: int): @@ -148,23 +439,27 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + rotary_emb: Optional[Union[LlamaRotaryEmbedding, ModernBertUnpaddedRotaryEmbedding]] = None, **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - attn_output = nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, is_causal=self.is_causal, - scale=self.scaling, enable_gqa=True).transpose(1,2) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + attn_output = LLAMA_ATTENTION_FUNCTION[self.config._attn_implementation]( + query_states, key_states, value_states, self.is_causal, self.scaling, + True, input_shape, hidden_shape, position_embeddings, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + 
rotary_emb=rotary_emb, + **kwargs, + ) - attn_output = attn_output.reshape(*input_shape, -1) attn_output = self.o_proj(attn_output) return attn_output @@ -181,6 +476,10 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + rotary_emb: Optional[Union[LlamaRotaryEmbedding, ModernBertUnpaddedRotaryEmbedding]] = None, **kwargs: Any, ) -> tuple[torch.Tensor]: @@ -189,6 +488,10 @@ def forward( hidden_states = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, + attention_mask=attention_mask, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -273,31 +576,69 @@ def __init__(self, config: LlamaConfig): self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding(config=config) self.can_generate = True self.tie_weights() def forward( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor, inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values: Optional[tuple] = None, use_cache: Optional[bool] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs: Any, ): - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + batch_size, seq_len = input_ids.shape[:2] + repad = False + cu_seqlens = None + max_seqlen = None + indices = None + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is None: + attention_mask = (input_ids != self.config.eos_token_id).long() + if inputs_embeds is None: + repad = True + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, + ) + else: + attention_mask = None + inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)) + + # For flash attention, we don't need position_embeddings since ModernBertUnpaddedRotaryEmbedding handles it + if self.config._attn_implementation == "flash_attention_2": + device = hidden_states.device if hidden_states is not None else input_ids.device + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.config.head_dim, + base=self.config.rope_theta, + max_seqlen=self.config.max_position_embeddings, + device=device, + ) + position_embeddings = None + else: + device = hidden_states.device if hidden_states is not None else input_ids.device + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + position_embeddings = self.rotary_emb(hidden_states, torch.arange(seq_len, device=device).unsqueeze(0)) for decoder_layer in self.layers[:self.config.num_hidden_layers]: hidden_states = decoder_layer( hidden_states, position_embeddings=position_embeddings, + attention_mask=attention_mask, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_emb=self.rotary_emb, **kwargs, ) - return self.lm_head(self.norm(hidden_states)) + hidden_states = self.norm(hidden_states) + if repad: + hidden_states = _pad_modernbert_output( + inputs=hidden_states, 
indices=indices, batch=batch_size, seqlen=seq_len, + ) + return self.lm_head(hidden_states) @classmethod def from_pretrained(cls, model_type: str, device_map: str = "auto", torch_dtype: torch.dtype = torch.bfloat16): @@ -372,16 +713,6 @@ def __init__( ) def forward(self, batch: dict[str, Any]) -> torch.Tensor: - # input_ids = batch['input_ids'] - - # # Create attention mask if not provided (mark non-padding tokens as 1) - # attention_mask = batch.get('attention_mask') - # if attention_mask is None: - # # Assume padding token is 0 (EOS token based on your config) - # attention_mask = (input_ids != 0).long() - # print('attention_mask:', attention_mask) - # print('sum of attention_mask:', attention_mask.sum(dim=1)) - # return self.model(input_ids=input_ids, attention_mask=attention_mask) return self.model(input_ids=batch['input_ids']) def loss(self, outputs: torch.Tensor, batch: dict[str, Any]) -> torch.Tensor: diff --git a/scripts/train/yamls/pretrain/custom_smollm2-135m.yaml b/scripts/train/yamls/pretrain/custom_smollm2-135m.yaml index 60a14c3..6931e3e 100644 --- a/scripts/train/yamls/pretrain/custom_smollm2-135m.yaml +++ b/scripts/train/yamls/pretrain/custom_smollm2-135m.yaml @@ -1,6 +1,6 @@ variables: data_local: datasets/tmp # datasets/c4_small - data_remote: hf://datasets/LocalResearchGroup/split-avelina-python-edu/tokenized/avelinapythonedu/1k/ # If blank, files must be present in data_local + data_remote: hf://datasets/LocalResearchGroup/split-NuminaMath-CoT/tokenized/numina/full/ # If blank, files must be present in data_local tokenizer_name: HuggingFaceTB/SmolLM2-135M global_seed: 42 max_seq_len: 2048 @@ -9,14 +9,12 @@ variables: max_seq_len: ${variables.max_seq_len} run_name: ${variables.run_name} -# Model - Using custom SmolLM2-135M model model: name: custom-smollm2-135m model_type: smollm2-135m - pretrained: false + pretrained: true pretrained_model_name_or_path: HuggingFaceTB/SmolLM2-135M -# Tokenizer tokenizer: name: ${variables.tokenizer_name} kwargs: @@ -24,9 +22,9 @@ tokenizer: loggers: wandb: - project: 'smollm2-135m-training-avelina' + project: 'torch-compile' entity: 'local-research-group' - name: 'smollm2-135m-training-avelina-1k' + name: 'numina-dynamic_epilogue_fusion' # Data loaders train_loader: @@ -40,6 +38,14 @@ train_loader: shuffle_seed: ${variables.global_seed} deorder_only_format: true eos_token_id: 0 + # Sequence packing configuration + # sequence_packing: true + # micro_batch_size: 4 # Should match device_train_microbatch_size + # packing_buffer_size: 20 # 5 * device_batch_size + # packing_prefetch_factor: 5 + # Optional: batch size warmup + # batch_size_warmup_min_size: 2 + # batch_size_warmup_tokens: 1000ba drop_last: true num_workers: 4 @@ -78,10 +84,10 @@ algorithms: clipping_threshold: 1.0 # Training duration and evaluation -max_duration: 1000ba -eval_interval: 100ba +max_duration: 1ba +eval_interval: 1ba eval_first: false -eval_subset_num_batches: 5 +eval_subset_num_batches: 0 save_overwrite: true save_interval: 0ba # Disable checkpointing @@ -89,7 +95,7 @@ save_interval: 0ba # Disable checkpointing seed: ${variables.global_seed} device_eval_batch_size: 2 device_train_microbatch_size: 4 -global_train_batch_size: 4 +global_train_batch_size: 8 precision: amp_bf16 # FSDP configuration @@ -101,6 +107,14 @@ fsdp_config: activation_cpu_offload: false limit_all_gathers: true +# Torch compile configuration +# compile_config: +# dynamic: true # Enable dynamic shapes +# options: +# epilogue_fusion: true # Fuses pointwise ops into templates (requires max-autotune) 
+# max_autotune: true # Enables automatic tuning of fusion patterns + # coordinate_descent_tuning: true # Enables coordinate descent tuning for optimization + # Logging progress_bar: true log_to_console: true @@ -112,11 +126,18 @@ callbacks: lr_monitor: {} memory_monitor: {} runtime_estimator: {} - text_generation: - prompts: - - "The future of artificial intelligence is" - - "In a world where technology" - - "def main():" - - "Gravity is" - max_new_tokens: 50 - temperature: 0. \ No newline at end of file + # batch_inspection: + # log_frequency: 1 + # sample_size: 1 + # log_to_console: true + # log_to_wandb: true + # text_generation: + # prompts: + # - "The future of artificial intelligence is" + # - "In a world where technology" + # - "def main():" + # - "Gravity is" + # - "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Local Research Group<|im_end|>\n<|im_start|>user\nWhat is 1+1?<|im_end|>\n<|im_start|>assistant\n" + # max_new_tokens: 50 + # temperature: 0.7 + # log_to_wandb: true \ No newline at end of file diff --git a/simple_train_smollm2.py b/simple_train_smollm2.py index bd0213c..17b805b 100644 --- a/simple_train_smollm2.py +++ b/simple_train_smollm2.py @@ -14,10 +14,6 @@ from llmfoundry.command_utils.train import train from omegaconf import OmegaConf -# Import callbacks to register them -import text_generation_callback # type: ignore -import batch_inspection_callback # type: ignore - # Set up logging logging.basicConfig( level=logging.INFO,
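
Usage sketch (illustrative, not part of the patch): because __init__.py now registers the two new callbacks under the names 'text_generation' and 'batch_inspection', they can be enabled by name from a training YAML rather than imported manually, which is why the ad-hoc imports in simple_train_smollm2.py are removed. A minimal callbacks section mirroring the commented-out block in custom_smollm2-135m.yaml (parameter values are the illustrative ones from that file, not required defaults):

callbacks:
  batch_inspection:
    log_frequency: 1
    sample_size: 1
    log_to_console: true
    log_to_wandb: true
  text_generation:
    prompts:
      - "The future of artificial intelligence is"
      - "def main():"
    max_new_tokens: 50
    temperature: 0.7
    log_to_wandb: true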