diff --git a/custom_model_training.py b/custom_model_training.py
new file mode 100644
index 0000000..37747dd
--- /dev/null
+++ b/custom_model_training.py
@@ -0,0 +1,128 @@
+import os
+import modal
+import sys
+from modal import Image, App, Secret, Volume
+
+import pathlib, datetime
+
+PYTHON_PATH = "/opt/conda/envs/llm-foundry/bin/python"
+
+# configuration read from environment variables
+TRAINING_GPU = os.environ.get("MODAL_GPU", "L4")
+TRAIN_YAML = os.environ.get("TRAIN_YAML", "")
+IS_PEFT = os.environ.get("IS_PEFT", "True")
+IS_PEFT = IS_PEFT in ("True", "true")
+
+OUTPUT_PRECISION = os.environ.get("OUTPUT_PRECISION", "bf16")
+
+# defaults --- make sure your Modal Volumes are named accordingly
+DATASET_BASE_PATH = "/datasets"
+DATASETS_VOLUME = Volume.from_name("lrg-datasets", create_if_missing=True)
+DATASETS_VOLUME_MOUNT_PATH = pathlib.Path("/datasets")
+MODEL_CHECKPOINT_VOLUME = Volume.from_name("lrg-model-checkpoints", create_if_missing=True)
+MODEL_CHECKPOINT_VOLUME_MOUNT_PATH = pathlib.Path("/model-checkpoints")
+
+app = App("custom-llama-training")
+
+# Build image from local Dockerfile
+image = Image.from_dockerfile("Dockerfile", gpu='L4')
+image = image.add_local_file(TRAIN_YAML, f"/llm-foundry/scripts/train/yamls/pretrain/{TRAIN_YAML}")
+# image = image.add_local_file("train.py", "/llm-foundry/llmfoundry/command_utils/train.py")
+
+
+@app.function(gpu=TRAINING_GPU, image=image, timeout=12*3600, secrets=[Secret.from_name("LRG")], # pyright: ignore[reportUntypedFunctionDecorator]
+              volumes={MODEL_CHECKPOINT_VOLUME_MOUNT_PATH: MODEL_CHECKPOINT_VOLUME,
+                       DATASETS_VOLUME_MOUNT_PATH: DATASETS_VOLUME},
+              max_containers=1)
+def _train(yaml_path):
+    """Run training (PEFT or full finetuning) with llm-foundry inside the Modal container."""
+    import os
+    import sys
+    import logging
+    from pathlib import Path
+
+    project_root = Path(__file__).parent
+    sys.path.insert(0, str(project_root))
+
+    from llmfoundry.models.llama.register import register_custom_llama_model
+    from llmfoundry.command_utils.train import train
+    from omegaconf import OmegaConf
+
+    # import text_generation_callback # type: ignore
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+    )
+    logger = logging.getLogger(__name__)
+
+    logger.info("Registering custom SmolLM2-135M model...")
+    register_custom_llama_model()
+    logger.info("Custom model registered successfully!")
+
+    config_path = f"scripts/train/yamls/pretrain/{yaml_path}"
+    logger.info(f"Loading configuration from: {config_path}")
+
+    if not os.path.exists(config_path):
+        raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+    config = OmegaConf.load(config_path)
+
+    save_folder = "lrg-model-checkpoints/hf_smollm2_135m_peft"
+    config.save_folder = save_folder
+    os.makedirs(save_folder, exist_ok=True)
+    logger.info(f"PEFT model checkpoints will be saved to: {save_folder}")
+
+    dataset_local = config.variables.data_local
+    dataset_remote = getattr(config.variables, 'data_remote', None)
+    if dataset_remote and str(dataset_remote).strip():
+        os.makedirs(dataset_local, exist_ok=True)
+        logger.info(
+            f"Streaming dataset from remote: {dataset_remote} with local cache: {dataset_local}")
+    else:
+        if not os.path.exists(dataset_local):
+            logger.warning(f"Dataset not found at: {dataset_local}")
+            return
+        logger.info(f"Using local dataset at: {dataset_local}")
+
+    if IS_PEFT and hasattr(config.model, 'peft_config') and config.model.peft_config:
+        peft_config = config.model.peft_config
+        logger.info("PEFT Configuration:")
+        logger.info(f" - Type: {peft_config.peft_type}")
+
logger.info(f" - Rank (r): {peft_config.r}") + logger.info(f" - Alpha: {peft_config.lora_alpha}") + logger.info(f" - Dropout: {peft_config.lora_dropout}") + logger.info(f" - Target modules: {peft_config.target_modules}") + logger.info(f" - Use RSLora: {peft_config.get('use_rslora', False)}") + logger.info(f" - Use DoRA: {peft_config.get('use_dora', False)}") + + logger.info("Starting PEFT training...") + try: + trainer = train(config) + logger.info("PEFT training completed successfully!") + logger.info(f"PEFT adapters saved to: {save_folder}") + del trainer + return "Training completed successfully" + except Exception as e: + logger.error(f"PEFT training failed: {e}") + import traceback + logger.error(traceback.format_exc()) + raise + else: + logger.warning("No PEFT configuration found!") + logger.info("Starting Full finetuning training...") + try: + trainer = train(config) + logger.info("Full training completed successfully!") + del trainer + return "Training completed successfully" + except Exception as e: + logger.error(f"Full training failed: {e}") + import traceback + logger.error(traceback.format_exc()) + raise + return + +@app.local_entrypoint() +def main(): + _train.remote(yaml_path=TRAIN_YAML) \ No newline at end of file diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index e23133f..a69756c 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -39,6 +39,9 @@ ) from llmfoundry.callbacks.run_timeout_callback import RunTimeoutCallback from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector +from llmfoundry.callbacks.text_generation_callback import TextGenerationCallback +from llmfoundry.callbacks.batch_inspection_callback import BatchInspectionCallback +from llmfoundry.callbacks.packing_efficiency_callback import PackingEfficiency from llmfoundry.registry import callbacks, callbacks_with_config callbacks.register('system_metrics_monitor', func=SystemMetricsMonitor) @@ -65,6 +68,9 @@ callbacks.register('nan_monitor', func=NaNMonitor) callbacks.register('kill_loss_spike', func=KillLossSpike) callbacks.register('load_checkpoint', func=LoadCheckpoint) +callbacks.register('text_generation', func=TextGenerationCallback) +callbacks.register('batch_inspection', func=BatchInspectionCallback) +callbacks.register('packing_efficiency', func=PackingEfficiency) callbacks_with_config.register('async_eval', func=AsyncEval) callbacks_with_config.register('curriculum_learning', func=CurriculumLearning) @@ -83,4 +89,7 @@ 'CurriculumLearning', 'LossPerpVsContextLengthLogger', 'KillLossSpike', + 'TextGenerationCallback', + 'BatchInspectionCallback', + 'PackingEfficiency', ] diff --git a/llmfoundry/callbacks/batch_inspection_callback.py b/llmfoundry/callbacks/batch_inspection_callback.py new file mode 100644 index 0000000..47ad6e4 --- /dev/null +++ b/llmfoundry/callbacks/batch_inspection_callback.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +"""Callback for inspecting batch data passed to the model's forward method.""" + +import logging +import torch +from typing import Any, Dict +from composer.core import Callback, State +from composer.loggers import Logger + +logger = logging.getLogger(__name__) + +class BatchInspectionCallback(Callback): + """Callback that logs detailed information about batches passed to the model.""" + + def __init__( + self, + log_frequency: int = 10, + sample_size: int = 3, + log_to_console: bool = True, + log_to_wandb: bool = True, + ): + """Initialize the callback. 
+ + Args: + log_frequency: Log batch info every N batches + sample_size: Number of sample values to show from each tensor + log_to_console: Whether to print to console + log_to_wandb: Whether to log to wandb + """ + self.log_frequency = log_frequency + self.sample_size = sample_size + self.log_to_console = log_to_console + self.log_to_wandb = log_to_wandb + self.batch_count = 0 + self.original_forward = None + + def _inspect_tensor(self, tensor: torch.Tensor, name: str) -> dict[str, Any]: + """Extract detailed information from a tensor.""" + info = { + 'shape': list(tensor.shape), + 'dtype': str(tensor.dtype), + 'device': str(tensor.device), + 'requires_grad': tensor.requires_grad, + } + + # Get sample values + if tensor.numel() > 0: + flattened = tensor.flatten() + + # For input_ids and labels, show first 10 and last 10 + if name in ['input_ids', 'labels']: + first_10 = flattened[:10].tolist() if flattened.numel() >= 10 else flattened.tolist() + last_10 = flattened[-10:].tolist() if flattened.numel() >= 10 else [] + info['first_10'] = first_10 + info['last_10'] = last_10 + elif name in ['attention_mask']: + info['sum'] = tensor.sum().item() + info['numel'] = tensor.numel() + else: + # For other tensors, use original sampling method + sample_indices = torch.linspace(0, flattened.numel()-1, min(self.sample_size, flattened.numel()), dtype=torch.long) + samples = flattened[sample_indices].tolist() + info['samples'] = samples + + # Basic statistics for numeric tensors + if tensor.dtype in [torch.float16, torch.float32, torch.float64, torch.int32, torch.int64]: + info['min'] = tensor.min().item() + info['max'] = tensor.max().item() + info['mean'] = tensor.float().mean().item() + + return info + + def _inspect_batch(self, batch: Any, prefix: str = "") -> dict[str, Any]: + """Recursively inspect batch structure.""" + batch_info = {} + + if isinstance(batch, torch.Tensor): + return self._inspect_tensor(batch, prefix) + elif isinstance(batch, dict): + batch_info['type'] = 'dict' + batch_info['keys'] = list(batch.keys()) + batch_info['contents'] = {} + for key, value in batch.items(): + batch_info['contents'][key] = self._inspect_batch(value, f"{prefix}.{key}" if prefix else key) + elif isinstance(batch, (list, tuple)): + batch_info['type'] = type(batch).__name__ + batch_info['length'] = len(batch) + batch_info['contents'] = {} + for i, item in enumerate(batch): + batch_info['contents'][f'item_{i}'] = self._inspect_batch(item, f"{prefix}[{i}]" if prefix else f"item_{i}") + else: + batch_info = { + 'type': type(batch).__name__, + 'value': str(batch)[:100], # Truncate long strings + } + + return batch_info + + def _log_batch_info(self, batch_info: dict[str, Any], state: State, logger: Logger): + """Log batch information to console and wandb.""" + if self.log_to_console: + print(f"\n{'='*80}") + if state and state.timestamp: + step_info = f"Step {state.timestamp.batch}" + else: + step_info = "Evaluation" + print(f"BATCH INSPECTION - {step_info}") + print(f"{'='*80}") + self._print_batch_info(batch_info) + print(f"{'='*80}\n") + + if self.log_to_wandb: + # Flatten the batch info for wandb logging + flat_info = self._flatten_batch_info(batch_info) + wandb_metrics = {} + for key, value in flat_info.items(): + if isinstance(value, (int, float, str, bool)): + wandb_metrics[f"batch_inspection/{key}"] = value + + if wandb_metrics and hasattr(logger, 'log_metrics'): + logger.log_metrics(wandb_metrics) + + def _print_batch_info(self, info: Any, indent: int = 0): + """Recursively print batch information.""" + 
prefix = " " * indent
+
+        if isinstance(info, dict):
+            if 'type' in info and info['type'] in ['dict', 'list', 'tuple']:
+                print(f"{prefix}Type: {info['type']}")
+                if 'keys' in info:
+                    print(f"{prefix}Keys: {info['keys']}")
+                if 'length' in info:
+                    print(f"{prefix}Length: {info['length']}")
+                if 'contents' in info:
+                    for key, value in info['contents'].items():
+                        print(f"{prefix}{key}:")
+                        self._print_batch_info(value, indent + 1)
+            else:
+                # Tensor info: sample lists and scalar stats are printed the same way
+                for key, value in info.items():
+                    print(f"{prefix}{key}: {value}")
+        else:
+            print(f"{prefix}{info}")
+
+    def _flatten_batch_info(self, info: Any, prefix: str = "") -> dict[str, Any]:
+        """Flatten nested batch info for wandb logging."""
+        flat = {}
+
+        if isinstance(info, dict):
+            if 'shape' in info:  # Tensor info
+                flat[f"{prefix}_shape"] = str(info['shape'])
+                flat[f"{prefix}_dtype"] = info['dtype']
+                flat[f"{prefix}_device"] = info['device']
+                if 'min' in info:
+                    flat[f"{prefix}_min"] = info['min']
+                if 'max' in info:
+                    flat[f"{prefix}_max"] = info['max']
+                if 'mean' in info:
+                    flat[f"{prefix}_mean"] = info['mean']
+            elif 'contents' in info:
+                for key, value in info['contents'].items():
+                    new_prefix = f"{prefix}_{key}" if prefix else key
+                    flat.update(self._flatten_batch_info(value, new_prefix))
+
+        return flat
+
+    def _create_forward_wrapper(self, original_forward):
+        """Create a wrapper around the model's forward method."""
+        def forward_wrapper(*args, **kwargs):
+            # Inspect the batch (typically the first argument)
+            if args and self.batch_count % self.log_frequency == 0:
+                batch = args[0] if len(args) > 0 else kwargs
+                batch_info = self._inspect_batch(batch)
+
+                # Log immediately, using the state and logger captured in batch_start
+                if hasattr(self, '_current_state') and hasattr(self, '_current_logger'):
+                    self._log_batch_info(batch_info, self._current_state, self._current_logger)
+
+            self.batch_count += 1
+            return original_forward(*args, **kwargs)
+
+        return forward_wrapper
+
+    def batch_start(self, state: State, logger: Logger) -> None:
+        """Called at the start of each batch."""
+        # Store state and logger for use in forward wrapper
+        self._current_state = state
+        self._current_logger = logger
+
+        # Wrap the model's forward method if not already wrapped
+        if self.original_forward is None:
+            model = state.model
+            if hasattr(model, 'forward'):
+                self.original_forward = model.forward
+                model.forward = self._create_forward_wrapper(self.original_forward)
+
+    def batch_end(self, state: State, logger: Logger) -> None:
+        """Called at the end of each batch."""
+        # Clean up references
+        self._current_state = None
+        self._current_logger = None
+
+    def fit_end(self, state: State, logger: Logger) -> None:
+        """Restore original forward method when training ends."""
+        if self.original_forward is not None:
+            state.model.forward = self.original_forward
+            self.original_forward = None
diff --git a/llmfoundry/callbacks/packing_efficiency_callback.py b/llmfoundry/callbacks/packing_efficiency_callback.py
new file mode 100644
index 0000000..107bafc
--- /dev/null
+++ b/llmfoundry/callbacks/packing_efficiency_callback.py
@@ -0,0 +1,27 @@
+# Copyright 2024 onwards Answer.AI, LightOn, and contributors
+# License: Apache-2.0
+
+from composer.core import Callback, State
+from composer.loggers import Logger
+
+__all__ = ["PackingEfficiency"]
+
+
+class PackingEfficiency(Callback):
+    """Records the packing efficiency for each batch."""
+
+    def __init__(self, log_interval: int =
100): + self.log_interval = log_interval + + def after_dataloader(self, state: State, logger: Logger) -> None: + if state.timestamp.batch.value % self.log_interval != 0: + return + logger.log_metrics( + { + "trainer/packing_efficiency": self._packing_efficiency(state), + }, + ) + + def _packing_efficiency(self, state: State) -> float: + return state.batch["attention_mask"].sum().item() / state.batch["attention_mask"].numel() + diff --git a/text_generation_callback.py b/llmfoundry/callbacks/text_generation_callback.py similarity index 58% rename from text_generation_callback.py rename to llmfoundry/callbacks/text_generation_callback.py index c5feb03..bd848a4 100644 --- a/text_generation_callback.py +++ b/llmfoundry/callbacks/text_generation_callback.py @@ -70,17 +70,50 @@ def _generate_and_log_text(self, state: State, logger: Logger, event_name: str): "prompt": prompt, "error": str(e), } - - if self.log_to_wandb and hasattr(logger, 'log_metrics'): + if self.log_to_wandb: + # Prepare table data for all successful generations + table_data = [] for key, value in generated_texts.items(): if "error" not in value: - logger.log_metrics({ - f"generation/{event_name}/{key}/prompt": str(value["prompt"]), - f"generation/{event_name}/{key}/text": str(value["generated"]), - }) - print(f"WandB Logged: {event_name} - {key}") - print(f" Prompt: {value['prompt']}") - print(f" Generated: {value['generated']}") + table_data.append([key, value["prompt"], value["generated"]]) + + if table_data: + # Log to W&B directly using proper wandb.Table + import wandb + from composer.loggers import WandBLogger + + # Find the WandBLogger destination + wandb_logger = None + for destination in logger.destinations: + if isinstance(destination, WandBLogger): + wandb_logger = destination + break + + if wandb_logger and hasattr(wandb_logger, '_run') and wandb_logger._run: + # Create proper W&B table + table = wandb.Table( # pyright: ignore[reportAttributeAccessIssue] + columns=["prompt_id", "prompt", "generated_text"], + data=table_data, + ) + # Log the table directly to W&B with step-specific name + wandb_logger._run.log({ + f"generation/{event_name}_step_{state.timestamp.batch.value}": table, + }, step=state.timestamp.batch.value) + print(f"WandB Logged: {event_name}") + print(f" Number of generations: {len(table_data)}") + print() + else: + # Fallback to Composer's log_table method + for destination in logger.destinations: + if hasattr(destination, 'log_table'): + destination.log_table( + columns=["prompt_id", "prompt", "generated_text"], + rows=table_data, + name=f"generation/{event_name}_step_{state.timestamp.batch.value}", + step=state.timestamp.batch.value, + ) + print(f"WandB Logged (fallback): {event_name}") + print(f" Number of generations: {len(table_data)}") print() except Exception as e: @@ -93,6 +126,3 @@ def fit_start(self, state: State, logger: Logger) -> None: def eval_start(self, state: State, logger: Logger) -> None: """Generate text before evaluation starts.""" self._generate_and_log_text(state, logger, "BEFORE_EVAL") - -from llmfoundry.registry import callbacks -callbacks.register('text_generation', func=TextGenerationCallback) \ No newline at end of file diff --git a/llmfoundry/data/sequence_packing.py b/llmfoundry/data/sequence_packing.py new file mode 100644 index 0000000..ae200ed --- /dev/null +++ b/llmfoundry/data/sequence_packing.py @@ -0,0 +1,497 @@ +# Copyright 2024 onwards Answer.AI, LightOn, and contributors +# License: Apache-2.0 + +import math +import threading +import time +from abc import ABC, 
abstractmethod +from collections import deque +from typing import Any, Generic, Iterable, NamedTuple, Optional, Sequence, TypeVar, Union + +import numpy as np +import torch +from composer.core import Time +from composer.core.types import Batch +from numba import njit + +class BatchSizeWarmupScheduler: + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + warmup_tokens: Union[str, Time, int], + world_size: int, + ): + self.min_batch_size = min_batch_size + self.max_batch_size = max_batch_size + + if isinstance(warmup_tokens, str): + self.warmup_tokens = Time.from_timestring(warmup_tokens).value + elif isinstance(warmup_tokens, Time): + self.warmup_tokens = warmup_tokens.value + else: + self.warmup_tokens = warmup_tokens + self.warmup_tokens = math.ceil(self.warmup_tokens / world_size) + self._step_thresholds = self._calculate_step_thresholds() + + def _calculate_step_thresholds(self): + total_batch_sizes = sum(range(self.min_batch_size, self.max_batch_size)) + steps_per_unit = self.warmup_tokens / total_batch_sizes + + thresholds = [] + cumsum = 0 + for batch_size in range(self.min_batch_size, self.max_batch_size): + cumsum += batch_size + steps = math.ceil(steps_per_unit * cumsum) + thresholds.append(steps) + return thresholds + + def __call__(self, current_step: int) -> int: + if current_step >= self.warmup_tokens: + return self.max_batch_size + + for i, threshold in enumerate(self._step_thresholds): + if current_step < threshold: + return self.min_batch_size + i + + # should never hit this, but just in case + return self.max_batch_size + +class SequencePacker(ABC): + def __init__( + self, + # params defining the incoming batches of seqs + src_iterable: Iterable[list[list[int]]], + src_batch_size: int, + src_max_seq_len: int, + # params defining outgoing batches of pseqs + out_batch_size: int, + out_pseq_len: int, + # params defining internal behavior + buffer_size: int, + pad_token_id: int = -1, + ignore_token_id: int = -100, + seed=42, + batch_size_warmup_min_size: Optional[int] = None, + batch_size_warmup_tokens: Optional[Union[str, Time]] = None, + world_size: int = 1, + ): + """Takes batches of unpacked, unpadded sequences (seqs) to batches of packed and padded sequences (pseqs). + + Every input batch must be a list[list[int]], a list of variable-length sequences of tokens. + + Every output batch is a dict with packed sequences and metadata. + + Args: + src_iterable: An iterable (e.g., a DataLoader), whose iterator yields one incoming batch, + where a batch is a list of unpadded, variable-length Sequences of token + IDs. + + src_batch_size: This is the INCOMING batch size, the number of seqs in one batch yielded + from `src_iterable`'s iterator. + + src_max_seq_len: The maximum number of tokens in a seq within an incoming batch. + + out_batch_size: the number of pseqs (packed seqs) in one outgoing batch + + out_pseq_len: the number of tokens per packed seq, in every outgoing batch + + buffer_size: The maximum number of seqs which may be buffered internally. + + pad_token_id: The token ID used for padding the space which cannot be filled to reach out_pseq_len. + + ignore_token_id: The token ID used to ignore tokens in labels. + + seed: Random seed for internal buffering and shuffling to ensure reproducibility. + + batch_size_warmup_min_size: If not None, the sequence packer will gradually increase the batch size from batch_size_warmup_min_size to out_batch_size over the course of the warmup_tokens. 
+
+            batch_size_warmup_tokens: If not None, the number of tokens (an int, a Composer Time, or a Composer time string) over which the batch size is gradually increased from batch_size_warmup_min_size to out_batch_size.
+
+            world_size: The number of processes participating in this training run.
+        """
+        assert buffer_size >= out_batch_size, f"required that {buffer_size=} >= {out_batch_size=}"
+        self.src_dataloader_len = len(src_iterable)
+        self.src_iterable = src_iterable
+        self.src_batch_size = src_batch_size
+        self.out_batch_size = out_batch_size
+        self.out_pseq_len = out_pseq_len
+        self.buffer_size = buffer_size
+        self.pad_token_id = pad_token_id
+        self.ignore_token_id = ignore_token_id
+        # internals
+        self.buffer = deque()  # internal buffer holds individual seqs, as tensors.
+        # for stats to report packing efficiency.
+        self._seqs_consumed = 0
+        self._seqs_emitted = 0
+        # Set random seed
+        self.seed = seed
+        self.epoch = -1
+        self._token_count = 0
+        self.batch_size_scheduler = None
+        if batch_size_warmup_min_size is not None and batch_size_warmup_tokens is not None:
+            self.batch_size_scheduler = BatchSizeWarmupScheduler(
+                batch_size_warmup_min_size, out_batch_size, batch_size_warmup_tokens, world_size,
+            )
+        else:
+            self.batch_size_scheduler = None
+
+    @property
+    def seqs_emitted(self):
+        """Number of seqs, incoming from src_iterable, which have been emitted in OUTGOING batches."""
+        return self._seqs_emitted
+
+    @property
+    def seqs_consumed(self):
+        """Number of seqs, incoming from src_iterable, which have been consumed."""
+        return self._seqs_consumed
+
+    def _reset_state(self):
+        self.epoch += 1
+        self.buffer.clear()
+        self._seqs_consumed = 0
+        self._seqs_emitted = 0
+        self.np_rng = np.random.default_rng(self.epoch + self.seed)
+
+        # Update the epoch for the sampler
+        if isinstance(self.src_iterable, torch.utils.data.dataloader.DataLoader):
+            if isinstance(self.src_iterable.sampler, torch.utils.data.distributed.DistributedSampler):
+                self.src_iterable.sampler.set_epoch(self.epoch)
+
+    def __iter__(self):
+        self._reset_state()
+        self.src_iterator = iter(self.src_iterable)
+        return self._generate_batches()
+
+    def __len__(self):
+        # rather than estimate the packed length of the dataset, we rely on Composer's ability
+        # to schedule training using the number of batches or tokens instead of epochs.
+        return None  # noqa: PLE0303
+
+    def _fill_buffer(self, max_items_to_add=float("inf")) -> int:
+        """Refills the internal buffer.
+
+        - max_items_to_add: an upper bound on the number of items to add
+
+        Returns: the number of items actually added.
+ """ + items_added = 0 + # NOTE: this should be >=, kept as is to match model training code + # TODO: change if training a new model + while (self.buffer_size - len(self.buffer)) > self.src_batch_size: + try: + # if pulling another batch would fetch more than the requested max, stop + if max_items_to_add < float("inf"): + if (items_added + self.src_batch_size) > max_items_to_add: + break + incoming_batch = next(self.src_iterator) + assert len(incoming_batch) <= self.src_batch_size, ( + f"expected {len(incoming_batch)=} <= {self.src_batch_size=}" + ) + for item in incoming_batch: + # Handle both dict format and tensor format + if isinstance(item, dict) and "input_ids" in item: + input_ids = item["input_ids"] + elif hasattr(item, "input_ids"): + input_ids = item.input_ids + else: + # If it's a tensor directly, use it + input_ids = item + + # Convert to numpy array if it's a tensor + if hasattr(input_ids, 'numpy'): + input_ids = input_ids.numpy() + elif hasattr(input_ids, 'tolist'): + input_ids = np.array(input_ids) + + if len(input_ids) > 0: # ignore empty sequences + self.buffer.append(input_ids) + items_added += 1 + self._seqs_consumed += 1 + except StopIteration: + break + return items_added + + def _generate_batches(self): + """Generates batches of packed sequences. + + The returned generator's iterator will always, when next() is called on it, either: + - return a valid dict batch + - raise StopIteration + """ + while True: + retval = self._create_batch() + if retval is None: + break + batch, lst_cu_seq_lens = retval + + assert isinstance(retval, tuple), f"Unexpected {type(retval)=}" + assert isinstance(retval[0], np.ndarray), f"Unexpected {type(retval[0])=}" + assert isinstance(retval[1], list), f"Unexpected {type(retval[1])=}" + + cu_seq_lens = [torch.tensor(x, dtype=torch.int32) for x in lst_cu_seq_lens] + max_seq_lens = [torch.max(x[1:] - x[:-1]).item() for x in cu_seq_lens] + assert isinstance(cu_seq_lens, list), f"Unexpected {type(cu_seq_lens)=}" + + # For decoder models, labels should be the same as input_ids + # The model will internally roll them by -1 for causal language modeling + labels = batch.copy() + # Set padding tokens to ignore index in labels + labels = np.where(batch == self.pad_token_id, self.ignore_token_id, labels) + + yieldval = { + "input_ids": torch.from_numpy(batch), + "labels": torch.from_numpy(labels), + "cu_seqlens": cu_seq_lens, + "max_seqlen": max_seq_lens, + "attention_mask": torch.from_numpy(np.where(batch == self.pad_token_id, 0, 1)), + } + self._token_count += yieldval["attention_mask"].sum().item() + yield yieldval + + @abstractmethod + def _create_batch(self) -> Optional[tuple[np.ndarray, list[list[int]]]]: + """Returns a batch of packed sequences with its cumulative seq length information. + + Or else, returns None if it cannot build a full outgoing batch. + + Must mutate self.buffer to remove the sequences that are packed into the batch. + + Returns: + (out_batch,cumulative_seq_len):tuple[np.ndarray, list[list[int]]] + where: + - out_batch is a numpy array of shape (out_batch_size, out_pseq_len); + - cum_seq_lens is a list of lists, where the outer list is of len out_batch_size, + and each inner list is of varying length, and contains the start positions of + every seq in the pseq, and the end position of the last seq in the pseq. 
+ """ + pass + + +@njit +def find_best_fit(remaining_spaces, seq_len): + valid_spaces = seq_len <= remaining_spaces + if np.any(valid_spaces): + valid_space_sizes = remaining_spaces[valid_spaces] + best_fit_idx = np.argmin(valid_space_sizes) + return np.arange(len(remaining_spaces))[valid_spaces][best_fit_idx] + return -1 + + +class GreedyBestFitSequencePacker(SequencePacker): + @classmethod + def from_composer( + cls, + src_iterable: Iterable[list[list[int]]], + batch_size: int = 512, + micro_batch_size: int = 32, + max_seq_len: int = 1024, + buffer_size: int = 5120, + # token values + pad_token_id: int = -1, + ignore_token_id: int = -100, + # transform values + seed=42, + batch_size_warmup_min_size: Optional[int] = None, + batch_size_warmup_tokens: Optional[Union[str, Time]] = None, + world_size: int = 1, + ) -> "GreedyBestFitSequencePacker": + if batch_size_warmup_min_size is not None: + if batch_size_warmup_min_size % micro_batch_size != 0: + raise ValueError(f"{batch_size_warmup_min_size=} must be a multiple of {micro_batch_size=}") + batch_size_warmup_min_size = int(batch_size_warmup_min_size / micro_batch_size) + return cls( + # input shape + src_iterable=src_iterable, + src_batch_size=batch_size, + src_max_seq_len=max_seq_len, + # output shape + out_batch_size=int(batch_size / micro_batch_size), + out_pseq_len=int(micro_batch_size * max_seq_len), + # internal + buffer_size=buffer_size, + # transformation + pad_token_id=pad_token_id, + ignore_token_id=ignore_token_id, + seed=seed, + batch_size_warmup_min_size=batch_size_warmup_min_size, + batch_size_warmup_tokens=batch_size_warmup_tokens, + world_size=world_size, + ) + + def _create_batch(self) -> Optional[tuple[np.ndarray, list[list[int]]]]: + if self.batch_size_scheduler: + self.out_batch_size = self.batch_size_scheduler(self._token_count) + + batch = np.full( + (self.out_batch_size, self.out_pseq_len), self.pad_token_id, dtype=np.int64, + ) # the pseqs being constructed + seq_counts = np.zeros(self.out_batch_size, dtype=np.int32) # the count of seqs per pseq + cum_seq_lens = [[0] for _ in range(self.out_batch_size)] + remaining_spaces = np.full( + (self.out_batch_size,), self.out_pseq_len, dtype=np.int32, + ) # the space remaining per pseq + temp_buffer = [] + + while True: + # Check if buffer has more items, and if not replenish + if not self.buffer: + items_to_fetch = self.buffer_size - len(temp_buffer) + items_added = self._fill_buffer(items_to_fetch) + if items_added == 0: + break + + seq = self.buffer.popleft() + seq_len = len(seq) + + # Find the best fit (smallest space that can accommodate the sequence) + best_fit_idx = find_best_fit(remaining_spaces, seq_len) + if best_fit_idx != -1: + end_pos = self.out_pseq_len - remaining_spaces[best_fit_idx] + batch[best_fit_idx, end_pos : end_pos + seq_len] = seq + seq_counts[best_fit_idx] += 1 + remaining_spaces[best_fit_idx] -= seq_len + cum_seq_lens[best_fit_idx].append(cum_seq_lens[best_fit_idx][-1] + seq_len) + else: + # Can't fit the sequence, save for next batch + temp_buffer.append(seq) + + # Add any sequences we skipped back to the start of the buffer + self.buffer.extendleft(temp_buffer) + + if np.all(seq_counts > 0): + self._seqs_emitted += np.sum(seq_counts) + for x in cum_seq_lens: + if x[-1] != self.out_pseq_len: + x.append(self.out_pseq_len) + return batch, cum_seq_lens + else: + # If we can't form a full batch, we return None to signal the end + return None + + +T = TypeVar("T") + + +class BufferedIterable(Generic[T]): + def __init__(self, iterable: Iterable[T], 
buffer_size: int): + """Args: + - iterable: an object which generates a fresh iterator on iter() and which implements len() + """ # noqa: D205 + self.iterable = iterable + self.buffer_size = buffer_size + + def __iter__(self): + return BufferedIterator(self.iterable, self.buffer_size) + + @property + def batch_size(self) -> int: + if isinstance(self.iterable, SequencePacker): + return self.iterable.out_batch_size + else: + raise TypeError("Expected a SequencePacker") + + +class BufferedIterator(Generic[T]): + def __init__(self, iterable: Iterable[T], buffer_size: int): + self.iterator = iter(iterable) + self.buffer = deque(maxlen=buffer_size) + self.buffer_size = buffer_size + self.lock = threading.Lock() + self.exhausted = False + self.filler_thread = threading.Thread(target=self._background_fill, daemon=True) + self.filler_thread.start() + + def _background_fill(self): + # Fill up the buffer, whenever possible, in the background + while not self.exhausted: + if len(self.buffer) < self.buffer_size: + try: + item = next(self.iterator) + with self.lock: + self.buffer.append(item) + except StopIteration: + self.exhausted = True + break + else: + time.sleep(0.01) # Sleep for a bit to avoid busy waiting + + def __iter__(self): + return self + + def __next__(self) -> T: + while True: + if not self.buffer: + if self.exhausted: + # We've exhausted the iterator and the buffer so we're done + raise StopIteration + else: + # The buffer is empty but the iterator is not exhausted yet. + # Let's give the filler thread a chance to add items to the buffer + time.sleep(0.01) + else: + with self.lock: + return self.buffer.popleft() + + +def split_packed_batch( + batch: Batch, microbatch_size: Union[int, float], padding_tolerance: float = 1.0, mark_dynamic: bool = True, +) -> Sequence: + # NOTE: Packed sequences are already packed into a microbatch size worth of tokens. + # So to correctly return a microbatch worth of data, we will simply return each item (i.e. 
microbatch_size 1) + + num_items = batch["input_ids"].shape[0] + split_inputs = [x.squeeze() for x in batch["input_ids"].split(1)] + split_labels = [x.squeeze() for x in batch["labels"].split(1)] + split_attention_masks = [x.squeeze() for x in batch["attention_mask"].split(1)] + split_cu_seqlens = batch["cu_seqlens"] + result = [] + for i in range(num_items): + attention_mask = split_attention_masks[i] + padding_amount = 1 - (attention_mask.sum() / len(attention_mask)) + + if padding_amount > padding_tolerance: + last_non_pad = attention_mask.nonzero().max() + input_ids = split_inputs[i][: last_non_pad + 1] + labels = split_labels[i][: last_non_pad + 1] + cu_seqlens = split_cu_seqlens[i][:-1] + attention_mask = attention_mask[: last_non_pad + 1] + else: + input_ids = split_inputs[i] + labels = split_labels[i] + cu_seqlens = split_cu_seqlens[i] + if mark_dynamic: + torch._dynamo.mark_dynamic(cu_seqlens, index=0) + microbatch = { + "input_ids": input_ids, + "labels": labels, + "cu_seqlens": cu_seqlens, + "max_seqlen": batch["max_seqlen"][i], + "attention_mask": attention_mask, + } + result.append(microbatch) + + assert all([x["input_ids"].shape[-1] == y["cu_seqlens"][-1] for x, y in zip(result, result)]) # noqa: C419 + return result + + +def get_num_tokens_in_packed_batch(batch: Batch, ignore_index: int = -100) -> int: + labels: torch.Tensor | list[torch.Tensor] = batch["labels"] + if isinstance(labels, torch.Tensor): + return (labels != ignore_index).sum().item() + elif isinstance(labels, list): + return sum([(x != ignore_index).sum().item() for x in labels]) + else: + raise TypeError('Expected a batch with a "labels" key of type list[Tensor] or Tensor') + + +def get_num_samples_in_packed_batch(batch: Batch) -> int: + # Number of sequences can be inferred from cu_seqlens arrays + cu_seqlens: torch.Tensor | list[torch.Tensor] = batch["cu_seqlens"] + if isinstance(cu_seqlens, torch.Tensor): + return cu_seqlens.size()[0] - 1 + elif isinstance(cu_seqlens, list): + return sum([x.size()[0] - 1 for x in batch["cu_seqlens"]]) + else: + raise TypeError('Expected a batch with a "cu_seqlens" key of type list[Tensor] or Tensor') + diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index dbbd575..212d3e1 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -27,6 +27,10 @@ SUPPORTED_MDS_ENCODING_TYPES, stream_remote_local_validate, ) +from llmfoundry.data.sequence_packing import ( + BufferedIterable, + GreedyBestFitSequencePacker, +) from llmfoundry.utils.registry_utils import construct_from_registry __all__ = [ @@ -314,6 +318,21 @@ def build_text_dataloader( dataset_cfg = dataset + # Check if sequence packing is enabled + sequence_packing = dataset_cfg.get('sequence_packing', False) + if sequence_packing: + return _build_sequence_packing_dataloader( + tokenizer=tokenizer, + device_batch_size=device_batch_size, + dataset=dataset, + drop_last=drop_last, + num_workers=num_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + timeout=timeout, + ) + # get kwargs dataset_cfg['replication'], dataset_batch_size = construct_from_registry( name='dataset_replication_validator', @@ -397,6 +416,104 @@ def build_text_dataloader( ) +def _build_sequence_packing_dataloader( + tokenizer: PreTrainedTokenizerBase, + device_batch_size: Union[int, float], + dataset: dict[str, Any], + drop_last: bool, + num_workers: int, + pin_memory: bool = True, + prefetch_factor: int = 2, + persistent_workers: bool = True, + timeout: 
int = 0, +) -> DataSpec: + """Build a dataloader with sequence packing for decoder models.""" + from composer.utils import dist + + dataset_cfg = dataset + + micro_batch_size = dataset_cfg.get('micro_batch_size', device_batch_size) + max_seq_len = dataset_cfg.get('max_seq_len', 2048) + packing_buffer_size = dataset_cfg.get('packing_buffer_size', 5 * device_batch_size) + batch_size_warmup_min_size = dataset_cfg.get('batch_size_warmup_min_size', None) + batch_size_warmup_tokens = dataset_cfg.get('batch_size_warmup_tokens', None) + packing_prefetch_factor = dataset_cfg.get('packing_prefetch_factor', 5) + + dataset_cfg['replication'], dataset_batch_size = construct_from_registry( + name='dataset_replication_validator', + registry=registry.dataset_replication_validators, + partial_function=False, + kwargs={ + 'dataset_cfg': dataset_cfg, + 'tokenizer': tokenizer, + 'device_batch_size': device_batch_size, + }, + ) + + streams = build_streams( + streams=dataset_cfg.pop('streams') + if 'streams' in dataset_cfg else None, + ) + + valid_streaming_text_dataset_parameters = inspect.signature( + StreamingTextDataset, + ).parameters + + valid_base_dataset_params = inspect.signature(StreamingDataset,).parameters + + dataset_config_subset_for_streaming_text_dataset = { + k: v + for k, v in dataset_cfg.items() + if k in valid_streaming_text_dataset_parameters or + k in valid_base_dataset_params + } + + text_dataset = StreamingTextDataset( + tokenizer=tokenizer, + streams=streams, + batch_size=dataset_batch_size, + **dataset_config_subset_for_streaming_text_dataset, + ) + + dataloader = DataLoader( + text_dataset, + collate_fn=lambda x: x, # No collation, just pass through + batch_size=dataset_batch_size, + drop_last=False, # Don't drop last for sequence packing + num_workers=num_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + timeout=timeout, + ) + + sequence_packer = GreedyBestFitSequencePacker.from_composer( + src_iterable=dataloader, + batch_size=dataset_batch_size, + micro_batch_size=micro_batch_size, + max_seq_len=max_seq_len, + buffer_size=packing_buffer_size, + pad_token_id=tokenizer.pad_token_id, + ignore_token_id=-100, # Standard ignore index for cross entropy + seed=dataset_cfg.get('shuffle_seed', 42), + batch_size_warmup_min_size=batch_size_warmup_min_size, + batch_size_warmup_tokens=batch_size_warmup_tokens, + world_size=dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1, + ) + + buffered_packer = BufferedIterable(sequence_packer, buffer_size=packing_prefetch_factor) + + return construct_from_registry( + name='data_spec', + registry=registry.data_specs, + partial_function=False, + kwargs={ + 'dl': buffered_packer, + 'dataset_cfg': dataset_cfg, + }, + ) + + # Helpful to test if your dataloader is working locally # Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out if __name__ == '__main__': diff --git a/llmfoundry/loggers/__init__.py b/llmfoundry/loggers/__init__.py index ae849ff..e2e3188 100644 --- a/llmfoundry/loggers/__init__.py +++ b/llmfoundry/loggers/__init__.py @@ -10,7 +10,7 @@ WandBLogger, ) -from llmfoundry.loggers.composer_aim_logger import AimLogger +# from llmfoundry.loggers.composer_aim_logger import AimLogger from llmfoundry.registry import loggers @@ -23,5 +23,5 @@ ) # for backwards compatibility loggers.register('mlflow', func=MLFlowLogger) loggers.register('mosaicml', func=MosaicMLLogger) -loggers.register('aim', 
func=AimLogger) +# loggers.register('aim', func=AimLogger) loggers.register('file_logger', func=FileLogger) diff --git a/llmfoundry/models/llama/custom_model.py b/llmfoundry/models/llama/custom_model.py index 4b72df9..e46001d 100644 --- a/llmfoundry/models/llama/custom_model.py +++ b/llmfoundry/models/llama/custom_model.py @@ -18,6 +18,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Copyright 2024 onwards Answer.AI, LightOn, and contributors +# License: Apache-2.0 + from typing import Any, Optional, Union, TYPE_CHECKING if TYPE_CHECKING: @@ -28,10 +31,16 @@ from torch import nn from transformers.models.llama.configuration_llama import LlamaConfig from transformers import PreTrainedTokenizerBase, PreTrainedModel +from transformers.utils import is_flash_attn_2_available from llmfoundry.data.finetuning.collator import CROSS_ENTROPY_IGNORE_INDEX from transformers import AutoModelForCausalLM +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_func + from flash_attn.layers.rotary import RotaryEmbedding + from flash_attn.ops.triton.rotary import apply_rotary + SMOLLM2_CONFIG_135M = LlamaConfig( attention_bias = False, attention_dropout = 0.0, @@ -59,8 +68,195 @@ transformers_version = "4.55.0.dev0", use_cache = True, vocab_size = 49152, + _attn_implementation = "sdpa", ) +# Modernbert unpadding and repadding +def _unpad_modernbert_input( + inputs: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Remove padding from input sequences. + + Args: + inputs: (batch, seqlen, ...) or (batch, seqlen) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + position_ids: (batch, seqlen), int, position ids + labels: (batch, seqlen), int, labels + + Returns: + unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + cu_seqlens: (batch + 1), the cumulative sequence lengths + max_seqlen_in_batch: int + unpadded_position_ids: (total_nnz) or None + unpadded_labels: (total_nnz) or None + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = int(seqlens_in_batch.max().item()) + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + if inputs.dim() == 2: + unpadded_inputs = inputs.flatten()[indices] + else: + batch, seqlen, *rest = inputs.shape + shape = batch * seqlen + unpadded_inputs = inputs.view(shape, *rest)[indices] + + unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None + unpadded_labels = labels.flatten()[indices] if labels is not None else None + + return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels + +def _pad_modernbert_output( + inputs: torch.Tensor, + indices: torch.Tensor, + batch: int, + seqlen: int, +) -> torch.Tensor: + """Add padding to sequences. + + Args: + inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. + indices: (total_nnz) + batch: int, batch size + seqlen: int, max sequence length + + Returns: + padded_inputs: (batch, seqlen, ...) 
or (batch, seqlen) + """ + if inputs.dim() == 1: + output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen) + else: + _, *rest = inputs.shape + output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) + output[indices] = inputs + padded_inputs = output.view(batch, seqlen, *rest) + + return padded_inputs + +class ApplyRotaryEmbUnpad(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + # (total_nnz, nheads, headdim) + apply_rotary( + x, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=False, + inplace=True, + ) + + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.max_seqlen = max_seqlen + return x + + @staticmethod + def backward(ctx, do): + cos, sin, cu_seqlens = ctx.saved_tensors + apply_rotary( + do, + cos, + sin, + seqlen_offsets=0, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=False, + inplace=True, + conjugate=True, + ) + + return do, None, None, None, None, None, None + + +def apply_rotary_unpadded( + x, + cos, + sin, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """Arguments: + x: (total_nnz, nheads, headdim) - input tensor for packed QKV. + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (total_nnz, dim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ # noqa: D205 + return ApplyRotaryEmbUnpad.apply(x, cos, sin, cu_seqlens, max_seqlen) + + +class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding): + """The rotary position embeddings applied directly to unpadded sequences.""" + + def __init__( + self, + dim: int, + base: float = 10000.0, + max_seqlen: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache + up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ, + the cos_sin_cache will be recomputed during the forward pass. + """ # noqa: D205 + super().__init__(dim=dim, base=base, device=device, interleaved=False) + self.max_seqlen = max_seqlen + + if max_seqlen is not None and device is not None and dtype is not None: + self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Apply rotary embedding *inplace* to x. 
+ x: (total_nnz, nheads, headdim) + cu_seqlens: (batch + 1,) cumulative sequence lengths + max_seqlen: int max seq length in the batch + """ # noqa: D205 + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=x.device, dtype=x.dtype) + + x = apply_rotary_unpadded( + x, + self._cos_cached, + self._sin_cached, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}" + class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() @@ -128,6 +324,101 @@ def __init__(self, config: LlamaConfig): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +def flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + is_causal: bool, + scaling: float, + enable_gqa: bool, + input_shape: tuple[int, int], + hidden_shape: tuple[int, int], + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]], + cu_seqlens: torch.Tensor, + max_seqlen: int, + rotary_emb: Union[LlamaRotaryEmbedding, ModernBertUnpaddedRotaryEmbedding], + target_dtype: torch.dtype = torch.bfloat16, + **kwargs: Any, +) -> tuple[torch.Tensor]: + # (total_seqlen, nheads, headdim) + query_states = query_states.view(hidden_shape) + key_states = key_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) + + query_states = rotary_emb(query_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + key_states = rotary_emb(key_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + convert_dtype = query_states.dtype not in (torch.float16, torch.bfloat16) + + if convert_dtype: + # FA2 implementation only supports fp16 and bf16. If FA2 is supported, + # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) + orig_dtype = query_states.dtype + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + deterministic=False, + causal=is_causal, + ) + attn = attn.to(orig_dtype) # type: ignore + else: + attn = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + deterministic=False, + causal=is_causal, + ) + total_tokens = attn.shape[0] + hidden_size = attn.shape[1] * attn.shape[2] # num_heads * head_dim + return attn.view(total_tokens, hidden_size) + +def sdpa_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + is_causal: bool, + scaling: float, + enable_gqa: bool, + input_shape: tuple[int, int], + hidden_shape: tuple[int, int], + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]], + **kwargs: Any, +) -> torch.Tensor: + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attn_output = nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, is_causal=is_causal, + scale=scaling, enable_gqa=enable_gqa).transpose(1,2) + + attn_output = attn_output.reshape(*input_shape, -1) + return attn_output + +LLAMA_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "sdpa": sdpa_attention_forward, +} + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig, layer_idx: int): @@ -148,23 +439,27 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + rotary_emb: Optional[Union[LlamaRotaryEmbedding, ModernBertUnpaddedRotaryEmbedding]] = None, **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - attn_output = nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, is_causal=self.is_causal, - scale=self.scaling, enable_gqa=True).transpose(1,2) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + attn_output = LLAMA_ATTENTION_FUNCTION[self.config._attn_implementation]( + query_states, key_states, value_states, self.is_causal, self.scaling, + True, input_shape, hidden_shape, position_embeddings, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + 
rotary_emb=rotary_emb, + **kwargs, + ) - attn_output = attn_output.reshape(*input_shape, -1) attn_output = self.o_proj(attn_output) return attn_output @@ -181,6 +476,10 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + rotary_emb: Optional[Union[LlamaRotaryEmbedding, ModernBertUnpaddedRotaryEmbedding]] = None, **kwargs: Any, ) -> tuple[torch.Tensor]: @@ -189,6 +488,10 @@ def forward( hidden_states = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, + attention_mask=attention_mask, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_emb=rotary_emb, **kwargs, ) hidden_states = residual + hidden_states @@ -273,31 +576,69 @@ def __init__(self, config: LlamaConfig): self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding(config=config) self.can_generate = True self.tie_weights() def forward( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor, inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values: Optional[tuple] = None, use_cache: Optional[bool] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs: Any, ): - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + batch_size, seq_len = input_ids.shape[:2] + repad = False + cu_seqlens = None + max_seqlen = None + indices = None + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is None: + attention_mask = (input_ids != self.config.eos_token_id).long() + if inputs_embeds is None: + repad = True + with torch.no_grad(): + input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( + inputs=input_ids, attention_mask=attention_mask, + ) + else: + attention_mask = None + inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)) + + # For flash attention, we don't need position_embeddings since ModernBertUnpaddedRotaryEmbedding handles it + if self.config._attn_implementation == "flash_attention_2": + device = hidden_states.device if hidden_states is not None else input_ids.device + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( + dim=self.config.head_dim, + base=self.config.rope_theta, + max_seqlen=self.config.max_position_embeddings, + device=device, + ) + position_embeddings = None + else: + device = hidden_states.device if hidden_states is not None else input_ids.device + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + position_embeddings = self.rotary_emb(hidden_states, torch.arange(seq_len, device=device).unsqueeze(0)) for decoder_layer in self.layers[:self.config.num_hidden_layers]: hidden_states = decoder_layer( hidden_states, position_embeddings=position_embeddings, + attention_mask=attention_mask, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_emb=self.rotary_emb, **kwargs, ) - return self.lm_head(self.norm(hidden_states)) + hidden_states = self.norm(hidden_states) + if repad: + hidden_states = _pad_modernbert_output( + inputs=hidden_states, 
indices=indices, batch=batch_size, seqlen=seq_len, + ) + return self.lm_head(hidden_states) @classmethod def from_pretrained(cls, model_type: str, device_map: str = "auto", torch_dtype: torch.dtype = torch.bfloat16): @@ -372,16 +713,6 @@ def __init__( ) def forward(self, batch: dict[str, Any]) -> torch.Tensor: - # input_ids = batch['input_ids'] - - # # Create attention mask if not provided (mark non-padding tokens as 1) - # attention_mask = batch.get('attention_mask') - # if attention_mask is None: - # # Assume padding token is 0 (EOS token based on your config) - # attention_mask = (input_ids != 0).long() - # print('attention_mask:', attention_mask) - # print('sum of attention_mask:', attention_mask.sum(dim=1)) - # return self.model(input_ids=input_ids, attention_mask=attention_mask) return self.model(input_ids=batch['input_ids']) def loss(self, outputs: torch.Tensor, batch: dict[str, Any]) -> torch.Tensor: diff --git a/pyproject.toml b/pyproject.toml index d19fa3a..99a5cc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -556,6 +556,7 @@ dependencies = [ "aim>=3.26.0,<4", "zstd>=1.5.6.1,!=1.5.6.2", "math_verify>=0.6.0", + "numba>=0.62.1", ] # Extra group for development requirements @@ -621,4 +622,4 @@ explicit = true [[tool.uv.index]] name = "pytorch-gpu" url = "https://download.pytorch.org/whl/cu124" -explicit = true \ No newline at end of file +explicit = true diff --git a/scripts/train/yamls/pretrain/custom_smollm2-135m.yaml b/scripts/train/yamls/pretrain/custom_smollm2-135m.yaml index 60a14c3..2b40dfa 100644 --- a/scripts/train/yamls/pretrain/custom_smollm2-135m.yaml +++ b/scripts/train/yamls/pretrain/custom_smollm2-135m.yaml @@ -1,22 +1,22 @@ variables: data_local: datasets/tmp # datasets/c4_small - data_remote: hf://datasets/LocalResearchGroup/split-avelina-python-edu/tokenized/avelinapythonedu/1k/ # If blank, files must be present in data_local + data_remote: hf://datasets/LocalResearchGroup/split-avelina-python-edu/tokenized/avelinapythonedu/full/ tokenizer_name: HuggingFaceTB/SmolLM2-135M global_seed: 42 - max_seq_len: 2048 + max_seq_len: 1024 run_name: custom_smollm2_135m_training + device_microbatch_size: 4 + global_train_batch_size: 8 max_seq_len: ${variables.max_seq_len} run_name: ${variables.run_name} -# Model - Using custom SmolLM2-135M model model: name: custom-smollm2-135m model_type: smollm2-135m pretrained: false pretrained_model_name_or_path: HuggingFaceTB/SmolLM2-135M -# Tokenizer tokenizer: name: ${variables.tokenizer_name} kwargs: @@ -24,9 +24,9 @@ tokenizer: loggers: wandb: - project: 'smollm2-135m-training-avelina' + project: 'sequence-packing' entity: 'local-research-group' - name: 'smollm2-135m-training-avelina-1k' + name: 'avelina-sequence-packing-2000ba-max-seq-len-1024_with-eval' # Data loaders train_loader: @@ -40,6 +40,14 @@ train_loader: shuffle_seed: ${variables.global_seed} deorder_only_format: true eos_token_id: 0 + # Sequence packing configuration + sequence_packing: true + micro_batch_size: ${variables.device_microbatch_size} # Should match device_train_microbatch_size + packing_buffer_size: 512 # 5 * device_batch_size + packing_prefetch_factor: 512 + # Optional: batch size warmup + # batch_size_warmup_min_size: 2 + # batch_size_warmup_tokens: 1000ba drop_last: true num_workers: 4 @@ -54,13 +62,21 @@ eval_loader: shuffle_seed: ${variables.global_seed} deorder_only_format: true eos_token_id: 0 + # Sequence packing configuration + sequence_packing: true + micro_batch_size: ${variables.device_microbatch_size} # Should match 
device_train_microbatch_size + packing_buffer_size: 512 # 5 * device_batch_size + packing_prefetch_factor: 512 + # Optional: batch size warmup + # batch_size_warmup_min_size: 2 + # batch_size_warmup_tokens: 1000ba drop_last: false num_workers: 4 # Training configuration scheduler: name: cosine_with_warmup - t_warmup: 100ba + t_warmup: 100ba # 8192tok alpha_f: 0.1 optimizer: @@ -78,18 +94,18 @@ algorithms: clipping_threshold: 1.0 # Training duration and evaluation -max_duration: 1000ba -eval_interval: 100ba +max_duration: 100ba # 16384tok +eval_interval: 100ba # 16384tok eval_first: false -eval_subset_num_batches: 5 +eval_subset_num_batches: 10 save_overwrite: true save_interval: 0ba # Disable checkpointing # System configuration seed: ${variables.global_seed} -device_eval_batch_size: 2 -device_train_microbatch_size: 4 -global_train_batch_size: 4 +device_eval_batch_size: 4 +device_train_microbatch_size: ${variables.device_microbatch_size} +global_train_batch_size: ${variables.global_train_batch_size} precision: amp_bf16 # FSDP configuration @@ -101,6 +117,14 @@ fsdp_config: activation_cpu_offload: false limit_all_gathers: true +# Torch compile configuration +compile_config: + dynamic: true # Enable dynamic shapes + options: + epilogue_fusion: true # Fuses pointwise ops into templates (requires max-autotune) + max_autotune: true # Enables automatic tuning of fusion patterns + coordinate_descent_tuning: true # Enables coordinate descent tuning for optimization + # Logging progress_bar: true log_to_console: true @@ -112,11 +136,20 @@ callbacks: lr_monitor: {} memory_monitor: {} runtime_estimator: {} - text_generation: - prompts: - - "The future of artificial intelligence is" - - "In a world where technology" - - "def main():" - - "Gravity is" - max_new_tokens: 50 - temperature: 0. 
\ No newline at end of file + # batch_inspection: + # log_frequency: 1 + # sample_size: 1 + # log_to_console: true + # log_to_wandb: false + packing_efficiency: + log_interval: 100 + # text_generation: + # prompts: + # - "The future of artificial intelligence is" + # - "In a world where technology" + # - "def main():" + # - "Gravity is" + # - "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Local Research Group<|im_end|>\n<|im_start|>user\nWhat is 1+1?<|im_end|>\n<|im_start|>assistant\n" + # max_new_tokens: 50 + # temperature: 0.7 + # log_to_wandb: true \ No newline at end of file diff --git a/scripts/train/yamls/pretrain/custom_smollm2-135m_peft.yaml b/scripts/train/yamls/pretrain/custom_smollm2-135m_peft.yaml index 85e40cd..f2a97f0 100644 --- a/scripts/train/yamls/pretrain/custom_smollm2-135m_peft.yaml +++ b/scripts/train/yamls/pretrain/custom_smollm2-135m_peft.yaml @@ -1,10 +1,12 @@ variables: data_local: datasets/tmp # datasets/c4_small - data_remote: hf://datasets/LocalResearchGroup/split-avelina-python-edu/tokenized/avelinapythonedu/1k/ # If blank, files must be present in data_local + data_remote: hf://datasets/LocalResearchGroup/split-avelina-python-edu/tokenized/avelinapythonedu/full/ # If blank, files must be present in data_local tokenizer_name: HuggingFaceTB/SmolLM2-135M global_seed: 42 - max_seq_len: 2048 + max_seq_len: 1024 run_name: custom_smollm2_135m_peft_training + device_microbatch_size: 2 + global_train_batch_size: 8 max_seq_len: ${variables.max_seq_len} run_name: ${variables.run_name} @@ -17,9 +19,9 @@ model: peft_config: peft_type: LORA task_type: CAUSAL_LM - r: 64 - lora_alpha: 128 - lora_dropout: 0.05 + r: 1 + lora_alpha: 2 + lora_dropout: 0.0 target_modules: - q_proj - k_proj @@ -38,9 +40,9 @@ tokenizer: loggers: wandb: - project: 'smollm2-135m-peft-training-avelina' + project: 'sequence-packing' entity: 'local-research-group' - name: 'smollm2-135m-peft-training-avelina-1k' + name: 'avelina-python-edu-lora-sequence-packing-100ba-max-seq-len-1024' train_loader: name: text @@ -53,6 +55,14 @@ train_loader: shuffle_seed: ${variables.global_seed} deorder_only_format: true eos_token_id: 0 + # Sequence packing configuration + sequence_packing: true + micro_batch_size: ${variables.device_microbatch_size} # Should match device_train_microbatch_size + packing_buffer_size: 512 # 5 * device_batch_size + packing_prefetch_factor: 512 + # Optional: batch size warmup + # batch_size_warmup_min_size: 2 + # batch_size_warmup_tokens: 1000ba drop_last: true num_workers: 4 @@ -67,12 +77,20 @@ eval_loader: shuffle_seed: ${variables.global_seed} deorder_only_format: true eos_token_id: 0 + # Sequence packing configuration + sequence_packing: true + micro_batch_size: ${variables.device_microbatch_size} # Should match device_train_microbatch_size + packing_buffer_size: 512 # 5 * device_batch_size + packing_prefetch_factor: 512 + # Optional: batch size warmup + # batch_size_warmup_min_size: 2 + # batch_size_warmup_tokens: 1000ba drop_last: false num_workers: 4 scheduler: name: cosine_with_warmup - t_warmup: 100ba + t_warmup: 100ba # 8192tok alpha_f: 0.1 optimizer: @@ -89,26 +107,35 @@ algorithms: clipping_type: norm clipping_threshold: 1.0 -max_duration: 100ba -eval_interval: 200ba +max_duration: 1ba # 16384tok +eval_interval: 1ba # 16384tok eval_first: false -eval_subset_num_batches: 50 +eval_subset_num_batches: 10 save_overwrite: true -save_interval: 200ba # Save PEFT checkpoints +save_interval: 0ba # Disable checkpointing seed: ${variables.global_seed} 
-device_eval_batch_size: 2 -device_train_microbatch_size: 2 -global_train_batch_size: 4 +device_eval_batch_size: 4 +device_train_microbatch_size: ${variables.device_microbatch_size} +global_train_batch_size: ${variables.global_train_batch_size} precision: amp_bf16 fsdp_config: sharding_strategy: FULL_SHARD mixed_precision: PURE - activation_checkpointing: false + activation_checkpointing: false # Disable for PEFT compatibility activation_checkpointing_reentrant: false activation_cpu_offload: false limit_all_gathers: true + # use_orig_params: true # Required for PEFT + FSDP + +# Torch compile configuration +# compile_config: +# dynamic: true # Enable dynamic shapes +# options: +# epilogue_fusion: true # Fuses pointwise ops into templates (requires max-autotune) +# max_autotune: true # Enables automatic tuning of fusion patterns +# coordinate_descent_tuning: true # Enables coordinate descent tuning for optimization progress_bar: true log_to_console: true @@ -120,11 +147,20 @@ callbacks: lr_monitor: {} memory_monitor: {} runtime_estimator: {} - text_generation: - prompts: - - "The future of artificial intelligence is" - - "In a world where technology" - - "The most important thing to remember is" - - "Gravity is" - max_new_tokens: 50 - temperature: 0. \ No newline at end of file + # batch_inspection: + # log_frequency: 1 + # sample_size: 1 + # log_to_console: true + # log_to_wandb: false + packing_efficiency: + log_interval: 1 + # text_generation: + # prompts: + # - "The future of artificial intelligence is" + # - "In a world where technology" + # - "def main():" + # - "Gravity is" + # - "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Local Research Group<|im_end|>\n<|im_start|>user\nWhat is 1+1?<|im_end|>\n<|im_start|>assistant\n" + # max_new_tokens: 50 + # temperature: 0.7 + # log_to_wandb: true \ No newline at end of file diff --git a/simple_train_smollm2.py b/simple_train_smollm2.py index bd0213c..17b805b 100644 --- a/simple_train_smollm2.py +++ b/simple_train_smollm2.py @@ -14,10 +14,6 @@ from llmfoundry.command_utils.train import train from omegaconf import OmegaConf -# Import callbacks to register them -import text_generation_callback # type: ignore -import batch_inspection_callback # type: ignore - # Set up logging logging.basicConfig( level=logging.INFO, diff --git a/simple_train_smollm2_peft.py b/simple_train_smollm2_peft.py index 6a3e21c..f4fd734 100644 --- a/simple_train_smollm2_peft.py +++ b/simple_train_smollm2_peft.py @@ -13,8 +13,6 @@ from llmfoundry.command_utils.train import train from omegaconf import OmegaConf -import text_generation_callback # type: ignore - logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', diff --git a/uv.lock b/uv.lock index 7c27d95..a3a445e 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.12.0" resolution-markers = [ "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux' and extra != 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu'", @@ -1538,6 +1539,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4c/fa/be89a49c640930180657482a74970cdcf6f7072c8d2471e1babe17a222dc/kiwisolver-1.4.8-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:be4816dc51c8a471749d664161b434912eee82f2ea66bd7628bd14583a833e85", size = 2349213 }, ] +[[package]] +name = "latex2sympy2-extended" +version = "1.10.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name 
= "antlr4-python3-runtime" }, + { name = "sympy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/de/472f9115c14c6f6d8a5889cabe3418283d708bde62ce00402c29441deed4/latex2sympy2_extended-1.10.2.tar.gz", hash = "sha256:41a517ffcc5a140e910a7d1646ce6ff440817e5f9d48fc8279d88bd0925bc389", size = 206188 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/60/dfbbf40e3a371388c0e03ff65b01319b7d4023e883df6d7261125772ffdc/latex2sympy2_extended-1.10.2-py3-none-any.whl", hash = "sha256:f910442c5b02a466c1046f47d05cc5285181068b882399281f30102715337fb7", size = 207855 }, +] + [[package]] name = "lightning-utilities" version = "0.12.0" @@ -1566,9 +1580,11 @@ dependencies = [ { name = "fsspec" }, { name = "gitpython" }, { name = "huggingface-hub" }, + { name = "math-verify" }, { name = "mosaicml", extra = ["mlflow", "peft", "wandb"] }, { name = "mosaicml-cli" }, { name = "mosaicml-streaming" }, + { name = "numba" }, { name = "omegaconf" }, { name = "onnx" }, { name = "onnxruntime" }, @@ -1582,8 +1598,8 @@ dependencies = [ [package.optional-dependencies] cpu = [ - { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "torch", version = "2.5.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-cpu') or (platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform == 'darwin' and extra == 'extra-11-llm-foundry-cpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "torch", version = "2.5.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-cpu') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu') or (sys_platform == 'darwin' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform == 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, ] flash = [ { name = "flash-attn" }, @@ -1591,8 +1607,8 @@ flash = [ gpu = [ { name = "packaging" }, { name = "setuptools" }, - { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "torch", version = "2.5.1+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "platform_machine != 'aarch64' or sys_platform != 'linux'" }, + { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-gpu') or (platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "torch", version = 
"2.5.1+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "(platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-gpu') or (extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, ] openai = [ { name = "openai" }, @@ -1625,9 +1641,11 @@ requires-dist = [ { name = "fsspec", specifier = ">=2023.6.0,<2024.12.0" }, { name = "gitpython", specifier = "==3.1.44" }, { name = "huggingface-hub", specifier = ">=0.19.0,<0.29" }, + { name = "math-verify", specifier = ">=0.6.0" }, { name = "mosaicml", extras = ["mlflow", "peft", "wandb"], specifier = ">=0.28.0,<0.29" }, { name = "mosaicml-cli", specifier = ">=0.6.10,<1" }, { name = "mosaicml-streaming", specifier = ">=0.11.0,<0.12" }, + { name = "numba", specifier = ">=0.62.1" }, { name = "omegaconf", specifier = ">=2.2.3,<3" }, { name = "onnx", specifier = "==1.17.0" }, { name = "onnxruntime", specifier = ">=1.19.2,<1.20.2" }, @@ -1644,6 +1662,7 @@ requires-dist = [ { name = "typer", specifier = "<1" }, { name = "zstd", specifier = ">=1.5.6.1,!=1.5.6.2" }, ] +provides-extras = ["gpu", "cpu", "flash", "openai"] [package.metadata.requires-dev] dev = [ @@ -1658,6 +1677,24 @@ dev = [ { name = "toml", specifier = ">=0.10.2,<0.11" }, ] +[[package]] +name = "llvmlite" +version = "0.45.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/99/8d/5baf1cef7f9c084fb35a8afbde88074f0d6a727bc63ef764fe0e7543ba40/llvmlite-0.45.1.tar.gz", hash = "sha256:09430bb9d0bb58fc45a45a57c7eae912850bedc095cd0810a57de109c69e1c32", size = 185600 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/7c/82cbd5c656e8991bcc110c69d05913be2229302a92acb96109e166ae31fb/llvmlite-0.45.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:28e763aba92fe9c72296911e040231d486447c01d4f90027c8e893d89d49b20e", size = 43043524 }, + { url = "https://files.pythonhosted.org/packages/9d/bc/5314005bb2c7ee9f33102c6456c18cc81745d7055155d1218f1624463774/llvmlite-0.45.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1a53f4b74ee9fd30cb3d27d904dadece67a7575198bd80e687ee76474620735f", size = 37253123 }, + { url = "https://files.pythonhosted.org/packages/96/76/0f7154952f037cb320b83e1c952ec4a19d5d689cf7d27cb8a26887d7bbc1/llvmlite-0.45.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b3796b1b1e1c14dcae34285d2f4ea488402fbd2c400ccf7137603ca3800864f", size = 56288211 }, + { url = "https://files.pythonhosted.org/packages/00/b1/0b581942be2683ceb6862d558979e87387e14ad65a1e4db0e7dd671fa315/llvmlite-0.45.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:779e2f2ceefef0f4368548685f0b4adde34e5f4b457e90391f570a10b348d433", size = 55140958 }, + { url = "https://files.pythonhosted.org/packages/33/94/9ba4ebcf4d541a325fd8098ddc073b663af75cc8b065b6059848f7d4dce7/llvmlite-0.45.1-cp312-cp312-win_amd64.whl", hash = "sha256:9e6c9949baf25d9aa9cd7cf0f6d011b9ca660dd17f5ba2b23bdbdb77cc86b116", size = 38132231 }, + { url = "https://files.pythonhosted.org/packages/1d/e2/c185bb7e88514d5025f93c6c4092f6120c6cea8fe938974ec9860fb03bbb/llvmlite-0.45.1-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:d9ea9e6f17569a4253515cc01dade70aba536476e3d750b2e18d81d7e670eb15", size = 43043524 }, + { url = "https://files.pythonhosted.org/packages/09/b8/b5437b9ecb2064e89ccf67dccae0d02cd38911705112dd0dcbfa9cd9a9de/llvmlite-0.45.1-cp313-cp313-macosx_12_0_arm64.whl", hash = 
"sha256:c9f3cadee1630ce4ac18ea38adebf2a4f57a89bd2740ce83746876797f6e0bfb", size = 37253121 }, + { url = "https://files.pythonhosted.org/packages/f7/97/ad1a907c0173a90dd4df7228f24a3ec61058bc1a9ff8a0caec20a0cc622e/llvmlite-0.45.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:57c48bf2e1083eedbc9406fb83c4e6483017879714916fe8be8a72a9672c995a", size = 56288210 }, + { url = "https://files.pythonhosted.org/packages/32/d8/c99c8ac7a326e9735401ead3116f7685a7ec652691aeb2615aa732b1fc4a/llvmlite-0.45.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3aa3dfceda4219ae39cf18806c60eeb518c1680ff834b8b311bd784160b9ce40", size = 55140957 }, + { url = "https://files.pythonhosted.org/packages/09/56/ed35668130e32dbfad2eb37356793b0a95f23494ab5be7d9bf5cb75850ee/llvmlite-0.45.1-cp313-cp313-win_amd64.whl", hash = "sha256:080e6f8d0778a8239cd47686d402cb66eb165e421efa9391366a9b7e5810a38b", size = 38132232 }, +] + [[package]] name = "mako" version = "1.3.9" @@ -1729,6 +1766,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739 }, ] +[[package]] +name = "math-verify" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "latex2sympy2-extended" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/b5/b1db6fa6b6c28ebbe1889ee11a4703a72a2ca7750ec415f4559c758cf01a/math_verify-0.8.0.tar.gz", hash = "sha256:3295e0adb94bfe553ff6e3189c44f1916a85aa24ab5d1900f2086a706e28f7c4", size = 60191 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/9f/59979f699b5c97334298f1295bc9fcdc9904d98d2276479bffff863d23b1/math_verify-0.8.0-py3-none-any.whl", hash = "sha256:31ca651296d817a9bb3fd58ca1fd0d192dcea709b1e5ecf2d0a4514c16f89087", size = 29994 }, +] + [[package]] name = "matplotlib" version = "3.10.0" @@ -2040,6 +2089,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, ] +[[package]] +name = "numba" +version = "0.62.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llvmlite" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/20/33dbdbfe60e5fd8e3dbfde299d106279a33d9f8308346022316781368591/numba-0.62.1.tar.gz", hash = "sha256:7b774242aa890e34c21200a1fc62e5b5757d5286267e71103257f4e2af0d5161", size = 2749817 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/fa/30fa6873e9f821c0ae755915a3ca444e6ff8d6a7b6860b669a3d33377ac7/numba-0.62.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:1b743b32f8fa5fff22e19c2e906db2f0a340782caf024477b97801b918cf0494", size = 2685346 }, + { url = "https://files.pythonhosted.org/packages/a9/d5/504ce8dc46e0dba2790c77e6b878ee65b60fe3e7d6d0006483ef6fde5a97/numba-0.62.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90fa21b0142bcf08ad8e32a97d25d0b84b1e921bc9423f8dda07d3652860eef6", size = 2688139 }, + { url = "https://files.pythonhosted.org/packages/50/5f/6a802741176c93f2ebe97ad90751894c7b0c922b52ba99a4395e79492205/numba-0.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6ef84d0ac19f1bf80431347b6f4ce3c39b7ec13f48f233a48c01e2ec06ecbc59", 
size = 3796453 }, + { url = "https://files.pythonhosted.org/packages/7e/df/efd21527d25150c4544eccc9d0b7260a5dec4b7e98b5a581990e05a133c0/numba-0.62.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9315cc5e441300e0ca07c828a627d92a6802bcbf27c5487f31ae73783c58da53", size = 3496451 }, + { url = "https://files.pythonhosted.org/packages/80/44/79bfdab12a02796bf4f1841630355c82b5a69933b1d50eb15c7fa37dabe8/numba-0.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:44e3aa6228039992f058f5ebfcfd372c83798e9464297bdad8cc79febcf7891e", size = 2745552 }, + { url = "https://files.pythonhosted.org/packages/22/76/501ea2c07c089ef1386868f33dff2978f43f51b854e34397b20fc55e0a58/numba-0.62.1-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:b72489ba8411cc9fdcaa2458d8f7677751e94f0109eeb53e5becfdc818c64afb", size = 2685766 }, + { url = "https://files.pythonhosted.org/packages/80/68/444986ed95350c0611d5c7b46828411c222ce41a0c76707c36425d27ce29/numba-0.62.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:44a1412095534a26fb5da2717bc755b57da5f3053965128fe3dc286652cc6a92", size = 2688741 }, + { url = "https://files.pythonhosted.org/packages/78/7e/bf2e3634993d57f95305c7cee4c9c6cb3c9c78404ee7b49569a0dfecfe33/numba-0.62.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8c9460b9e936c5bd2f0570e20a0a5909ee6e8b694fd958b210e3bde3a6dba2d7", size = 3804576 }, + { url = "https://files.pythonhosted.org/packages/e8/b6/8a1723fff71f63bbb1354bdc60a1513a068acc0f5322f58da6f022d20247/numba-0.62.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:728f91a874192df22d74e3fd42c12900b7ce7190b1aad3574c6c61b08313e4c5", size = 3503367 }, + { url = "https://files.pythonhosted.org/packages/9c/ec/9d414e7a80d6d1dc4af0e07c6bfe293ce0b04ea4d0ed6c45dad9bd6e72eb/numba-0.62.1-cp313-cp313-win_amd64.whl", hash = "sha256:bbf3f88b461514287df66bc8d0307e949b09f2b6f67da92265094e8fa1282dd8", size = 2745529 }, +] + [[package]] name = "numpy" version = "2.1.3" @@ -2123,7 +2194,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-gpu') or (extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-gpu') or (platform_machine == 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -2135,7 +2206,7 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-gpu') or (extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 
'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-gpu') or (platform_machine == 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, @@ -2158,9 +2229,9 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-gpu') or (extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-gpu') or (extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-gpu') or (extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-gpu') or (platform_machine == 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-gpu') or (platform_machine == 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-gpu') or (platform_machine == 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, @@ -2173,7 +2244,7 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-gpu') or (extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-gpu') or (platform_machine == 
'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, @@ -3691,7 +3762,7 @@ name = "triton" version = "3.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-gpu') or (extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, + { name = "filelock", marker = "(python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-llm-foundry-gpu') or (python_full_version >= '3.13' and platform_machine != 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (platform_machine == 'aarch64' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu') or (sys_platform != 'linux' and extra == 'extra-11-llm-foundry-cpu' and extra == 'extra-11-llm-foundry-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444 },
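
Note on the flash_attention_2 path added to the custom Llama model above: when the attention implementation is flash_attention_2, the forward pass derives an attention mask from input_ids != eos_token_id, strips padding with _unpad_modernbert_input before the decoder stack (carrying indices, cu_seqlens, and max_seqlen through each layer), and re-pads the hidden states with _pad_modernbert_output before the LM head. The sketch below is a minimal, self-contained illustration of that unpad/repad round trip in plain PyTorch; the names unpad and repad are placeholders, not the ModernBERT helpers the patch imports, and the real helpers' signatures may differ.

# Hedged sketch: illustrative unpad/repad helpers, not the repo's
# _unpad_modernbert_input / _pad_modernbert_output.
import torch
import torch.nn.functional as F

def unpad(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # Flatten (batch, seq) down to the non-padding tokens only, keeping the
    # metadata that varlen flash-attention kernels need.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)              # tokens per sequence
    indices = attention_mask.flatten().nonzero(as_tuple=False).flatten() # positions of real tokens
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    flat_ids = input_ids.flatten()[indices]                              # (total_tokens,)
    return flat_ids, indices, cu_seqlens, max_seqlen

def repad(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    # Scatter (total_tokens, hidden) back into a padded (batch, seq, hidden)
    # tensor so the LM head sees the shape the rest of the pipeline expects.
    out = hidden_states.new_zeros(batch * seqlen, hidden_states.shape[-1])
    out[indices] = hidden_states
    return out.view(batch, seqlen, -1)

if __name__ == "__main__":
    ids = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]])               # 0 == eos/pad, as in the yamls
    mask = (ids != 0).long()
    flat_ids, indices, cu_seqlens, max_seqlen = unpad(ids, mask)
    # flat_ids == [5, 6, 7, 8, 9]; cu_seqlens == [0, 3, 5]; max_seqlen == 3
    hidden = torch.randn(flat_ids.shape[0], 16)                          # stand-in for decoder output
    padded = repad(hidden, indices, batch=2, seqlen=5)                   # shape (2, 5, 16)

In the patch itself the same metadata (indices, cu_seqlens, max_seqlen) is threaded through every decoder layer so each attention module can run the varlen kernel directly on the packed token stream, and only the final hidden states are re-padded.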
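
On the sequence_packing and packing_efficiency additions in the two training YAMLs: the point of packing is to raise the fraction of non-padding tokens per micro-batch at max_seq_len 1024. A rough way to measure that fraction is shown below; this is an assumption about what a packing-efficiency style metric reports, since the PackingEfficiency callback's actual fields and log keys are not part of this diff.

# Hedged sketch: padding efficiency of a batch, assuming eos/pad id 0 as in the yamls.
import torch

def packing_efficiency(input_ids: torch.Tensor, pad_token_id: int = 0) -> float:
    # Fraction of positions in the (batch, seq) tensor that carry real tokens.
    mask = (input_ids != pad_token_id)
    return mask.sum().item() / mask.numel()

batch = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]])
print(f"packing efficiency: {packing_efficiency(batch):.2f}")  # 0.50 for this toy batch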