128 changes: 128 additions & 0 deletions custom_model_training.py
@@ -0,0 +1,128 @@
import os
import pathlib

from modal import Image, App, Secret, Volume

PYTHON_PATH = "/opt/conda/envs/llm-foundry/bin/python"

# Configuration is read from environment variables.
TRAINING_GPU = os.environ.get("MODAL_GPU", "L4")
TRAIN_YAML = os.environ.get("TRAIN_YAML", "")
IS_PEFT = os.environ.get("IS_PEFT", "True") in ("True", "true")

OUTPUT_PRECISION = os.environ.get("OUTPUT_PRECISION", "bf16")

# Defaults --- make sure your Modal Volumes are named accordingly
DATASET_BASE_PATH = "/datasets"
DATASETS_VOLUME = Volume.from_name("lrg-datasets", create_if_missing=True)
DATASETS_VOLUME_MOUNT_PATH = pathlib.Path("/datasets")
MODEL_CHECKPOINT_VOLUME = Volume.from_name("lrg-model-checkpoints", create_if_missing=True)
MODEL_CHECKPOINT_VOLUME_MOUNT_PATH = pathlib.Path("/model-checkpoints")

app = App("custom-llama-training")

# Build the image from the local Dockerfile; the build is pinned to an L4 GPU so
# CUDA-dependent packages can be installed, independent of TRAINING_GPU at runtime.
image = Image.from_dockerfile("Dockerfile", gpu="L4")
if not TRAIN_YAML:
    raise ValueError("Set the TRAIN_YAML environment variable to the training YAML filename.")
image = image.add_local_file(TRAIN_YAML, f"/llm-foundry/scripts/train/yamls/pretrain/{TRAIN_YAML}")
# image = image.add_local_file("train.py", "/llm-foundry/llmfoundry/command_utils/train.py")


@app.function(  # pyright: ignore[reportUntypedFunctionDecorator]
    gpu=TRAINING_GPU,
    image=image,
    timeout=12 * 3600,
    secrets=[Secret.from_name("LRG")],
    volumes={
        MODEL_CHECKPOINT_VOLUME_MOUNT_PATH: MODEL_CHECKPOINT_VOLUME,
        DATASETS_VOLUME_MOUNT_PATH: DATASETS_VOLUME,
    },
    max_containers=1,
)
def _train(yaml_path):
    """Run PEFT or full finetuning for the custom model inside the Modal container."""
    import logging
    import os
    import sys
    from pathlib import Path

    project_root = Path(__file__).parent
    sys.path.insert(0, str(project_root))

    from llmfoundry.models.llama.register import register_custom_llama_model
    from llmfoundry.command_utils.train import train
    from omegaconf import OmegaConf

    # import text_generation_callback  # type: ignore

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )
    logger = logging.getLogger(__name__)

    logger.info("Registering custom SmolLM2-135M model...")
    register_custom_llama_model()
    logger.info("Custom model registered successfully!")

    config_path = f"scripts/train/yamls/pretrain/{yaml_path}"
    logger.info(f"Loading configuration from: {config_path}")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    config = OmegaConf.load(config_path)

    # Save under the mounted checkpoint volume so checkpoints persist; the volume
    # is mounted at /model-checkpoints, not under the volume's name.
    save_folder = str(MODEL_CHECKPOINT_VOLUME_MOUNT_PATH / "hf_smollm2_135m_peft")
    config.save_folder = save_folder
    os.makedirs(save_folder, exist_ok=True)
    logger.info(f"PEFT model checkpoints will be saved to: {save_folder}")

    dataset_local = config.variables.data_local
    dataset_remote = getattr(config.variables, 'data_remote', None)
    if dataset_remote and str(dataset_remote).strip():
        os.makedirs(dataset_local, exist_ok=True)
        logger.info(
            f"Streaming dataset from remote: {dataset_remote} with local cache: {dataset_local}")
    else:
        if not os.path.exists(dataset_local):
            logger.warning(f"Dataset not found at: {dataset_local}")
            return
        logger.info(f"Using local dataset at: {dataset_local}")

    if IS_PEFT and hasattr(config.model, 'peft_config') and config.model.peft_config:
        peft_config = config.model.peft_config
        logger.info("PEFT Configuration:")
        logger.info(f"  - Type: {peft_config.peft_type}")
        logger.info(f"  - Rank (r): {peft_config.r}")
        logger.info(f"  - Alpha: {peft_config.lora_alpha}")
        logger.info(f"  - Dropout: {peft_config.lora_dropout}")
        logger.info(f"  - Target modules: {peft_config.target_modules}")
        logger.info(f"  - Use RSLora: {peft_config.get('use_rslora', False)}")
        logger.info(f"  - Use DoRA: {peft_config.get('use_dora', False)}")

        logger.info("Starting PEFT training...")
        try:
            trainer = train(config)
            logger.info("PEFT training completed successfully!")
            logger.info(f"PEFT adapters saved to: {save_folder}")
            del trainer
            return "Training completed successfully"
        except Exception as e:
            logger.error(f"PEFT training failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise
    else:
        logger.warning("No PEFT configuration found!")
        logger.info("Starting full finetuning...")
        try:
            trainer = train(config)
            logger.info("Full training completed successfully!")
            del trainer
            return "Training completed successfully"
        except Exception as e:
            logger.error(f"Full training failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise

@app.local_entrypoint()
def main():
    _train.remote(yaml_path=TRAIN_YAML)
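
To launch a run, set the environment variables above and invoke the script through the Modal CLI, e.g. TRAIN_YAML=my_config.yaml MODAL_GPU=A100 IS_PEFT=True modal run custom_model_training.py, where my_config.yaml stands in for a config under scripts/train/yamls/pretrain/.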
9 changes: 9 additions & 0 deletions llmfoundry/callbacks/__init__.py
@@ -39,6 +39,9 @@
)
from llmfoundry.callbacks.run_timeout_callback import RunTimeoutCallback
from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
from llmfoundry.callbacks.text_generation_callback import TextGenerationCallback
from llmfoundry.callbacks.batch_inspection_callback import BatchInspectionCallback
from llmfoundry.callbacks.packing_efficiency_callback import PackingEfficiency
from llmfoundry.registry import callbacks, callbacks_with_config

callbacks.register('system_metrics_monitor', func=SystemMetricsMonitor)
@@ -65,6 +68,9 @@
callbacks.register('nan_monitor', func=NaNMonitor)
callbacks.register('kill_loss_spike', func=KillLossSpike)
callbacks.register('load_checkpoint', func=LoadCheckpoint)
callbacks.register('text_generation', func=TextGenerationCallback)
callbacks.register('batch_inspection', func=BatchInspectionCallback)
callbacks.register('packing_efficiency', func=PackingEfficiency)

callbacks_with_config.register('async_eval', func=AsyncEval)
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
@@ -83,4 +89,7 @@
'CurriculumLearning',
'LossPerpVsContextLengthLogger',
'KillLossSpike',
'TextGenerationCallback',
'BatchInspectionCallback',
'PackingEfficiency',
]
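
For reference, a minimal sketch of attaching the newly registered callbacks to a Composer Trainer directly; in LLM Foundry's YAML flow the same callbacks are enabled by their registered names under the callbacks section of a train config. The composer_model and train_loader names below are assumed placeholders, not part of this PR:

from composer import Trainer

from llmfoundry.callbacks import BatchInspectionCallback, PackingEfficiency

callbacks = [
    BatchInspectionCallback(log_frequency=50, log_to_wandb=False),
    PackingEfficiency(log_interval=100),
]

# Hypothetical wiring; composer_model and train_loader must be defined by the caller.
# trainer = Trainer(
#     model=composer_model,
#     train_dataloader=train_loader,
#     max_duration='100ba',
#     callbacks=callbacks,
# )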
210 changes: 210 additions & 0 deletions llmfoundry/callbacks/batch_inspection_callback.py
@@ -0,0 +1,210 @@
#!/usr/bin/env python3
"""Callback for inspecting batch data passed to the model's forward method."""

import logging
from typing import Any

import torch
from composer.core import Callback, State
from composer.loggers import Logger

logger = logging.getLogger(__name__)


class BatchInspectionCallback(Callback):
    """Callback that logs detailed information about batches passed to the model."""

    def __init__(
        self,
        log_frequency: int = 10,
        sample_size: int = 3,
        log_to_console: bool = True,
        log_to_wandb: bool = True,
    ):
        """Initialize the callback.

        Args:
            log_frequency: Log batch info every N batches.
            sample_size: Number of sample values to show from each tensor.
            log_to_console: Whether to print to console.
            log_to_wandb: Whether to log to wandb.
        """
        self.log_frequency = log_frequency
        self.sample_size = sample_size
        self.log_to_console = log_to_console
        self.log_to_wandb = log_to_wandb
        self.batch_count = 0
        self.original_forward = None

    def _inspect_tensor(self, tensor: torch.Tensor, name: str) -> dict[str, Any]:
        """Extract detailed information from a tensor."""
        info = {
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype),
            'device': str(tensor.device),
            'requires_grad': tensor.requires_grad,
        }

        # Get sample values
        if tensor.numel() > 0:
            flattened = tensor.flatten()

            # For input_ids and labels, show the first 10 and last 10 values
            if name in ['input_ids', 'labels']:
                info['first_10'] = flattened[:10].tolist() if flattened.numel() >= 10 else flattened.tolist()
                info['last_10'] = flattened[-10:].tolist() if flattened.numel() >= 10 else []
            elif name in ['attention_mask']:
                info['sum'] = tensor.sum().item()
                info['numel'] = tensor.numel()
            else:
                # For other tensors, sample evenly spaced values
                sample_indices = torch.linspace(
                    0, flattened.numel() - 1,
                    min(self.sample_size, flattened.numel()),
                    dtype=torch.long,
                )
                info['samples'] = flattened[sample_indices].tolist()

            # Basic statistics for numeric tensors
            if tensor.dtype in [torch.float16, torch.float32, torch.float64, torch.int32, torch.int64]:
                info['min'] = tensor.min().item()
                info['max'] = tensor.max().item()
                info['mean'] = tensor.float().mean().item()

        return info

    def _inspect_batch(self, batch: Any, prefix: str = "") -> dict[str, Any]:
        """Recursively inspect batch structure."""
        batch_info = {}

        if isinstance(batch, torch.Tensor):
            return self._inspect_tensor(batch, prefix)
        elif isinstance(batch, dict):
            batch_info['type'] = 'dict'
            batch_info['keys'] = list(batch.keys())
            batch_info['contents'] = {}
            for key, value in batch.items():
                batch_info['contents'][key] = self._inspect_batch(value, f"{prefix}.{key}" if prefix else key)
        elif isinstance(batch, (list, tuple)):
            batch_info['type'] = type(batch).__name__
            batch_info['length'] = len(batch)
            batch_info['contents'] = {}
            for i, item in enumerate(batch):
                batch_info['contents'][f'item_{i}'] = self._inspect_batch(item, f"{prefix}[{i}]" if prefix else f"item_{i}")
        else:
            batch_info = {
                'type': type(batch).__name__,
                'value': str(batch)[:100],  # Truncate long strings
            }

        return batch_info

    def _log_batch_info(self, batch_info: dict[str, Any], state: State, logger: Logger):
        """Log batch information to console and wandb."""
        if self.log_to_console:
            print(f"\n{'='*80}")
            if state and state.timestamp:
                step_info = f"Step {state.timestamp.batch}"
            else:
                step_info = "Evaluation"
            print(f"BATCH INSPECTION - {step_info}")
            print(f"{'='*80}")
            self._print_batch_info(batch_info)
            print(f"{'='*80}\n")

        if self.log_to_wandb:
            # Flatten the batch info for wandb logging
            flat_info = self._flatten_batch_info(batch_info)
            wandb_metrics = {}
            for key, value in flat_info.items():
                if isinstance(value, (int, float, str, bool)):
                    wandb_metrics[f"batch_inspection/{key}"] = value

            if wandb_metrics and hasattr(logger, 'log_metrics'):
                logger.log_metrics(wandb_metrics)

    def _print_batch_info(self, info: Any, indent: int = 0):
        """Recursively print batch information."""
        prefix = "  " * indent

        if isinstance(info, dict):
            if 'type' in info and info['type'] in ['dict', 'list', 'tuple']:
                print(f"{prefix}Type: {info['type']}")
                if 'keys' in info:
                    print(f"{prefix}Keys: {info['keys']}")
                if 'length' in info:
                    print(f"{prefix}Length: {info['length']}")
                if 'contents' in info:
                    for key, value in info['contents'].items():
                        print(f"{prefix}{key}:")
                        self._print_batch_info(value, indent + 1)
            else:
                # Tensor info: print each recorded field
                for key, value in info.items():
                    print(f"{prefix}{key}: {value}")
        else:
            print(f"{prefix}{info}")

    def _flatten_batch_info(self, info: Any, prefix: str = "") -> dict[str, Any]:
        """Flatten nested batch info for wandb logging."""
        flat = {}

        if isinstance(info, dict):
            if 'shape' in info:  # Tensor info
                flat[f"{prefix}_shape"] = str(info['shape'])
                flat[f"{prefix}_dtype"] = info['dtype']
                flat[f"{prefix}_device"] = info['device']
                if 'min' in info:
                    flat[f"{prefix}_min"] = info['min']
                if 'max' in info:
                    flat[f"{prefix}_max"] = info['max']
                if 'mean' in info:
                    flat[f"{prefix}_mean"] = info['mean']
            elif 'contents' in info:
                for key, value in info['contents'].items():
                    new_prefix = f"{prefix}_{key}" if prefix else key
                    flat.update(self._flatten_batch_info(value, new_prefix))

        return flat

    def _create_forward_wrapper(self, original_forward):
        """Create a wrapper around the model's forward method."""
        def forward_wrapper(*args, **kwargs):
            # Inspect the batch (typically the first positional argument)
            if args and self.batch_count % self.log_frequency == 0:
                batch = args[0]
                batch_info = self._inspect_batch(batch)

                # Log immediately, using the state/logger captured in batch_start
                if hasattr(self, '_current_state') and hasattr(self, '_current_logger'):
                    self._log_batch_info(batch_info, self._current_state, self._current_logger)

            self.batch_count += 1
            return original_forward(*args, **kwargs)

        return forward_wrapper

    def batch_start(self, state: State, logger: Logger) -> None:
        """Called at the start of each batch."""
        # Store state and logger for use in the forward wrapper
        self._current_state = state
        self._current_logger = logger

        # Wrap the model's forward method if not already wrapped
        if self.original_forward is None:
            model = state.model
            if hasattr(model, 'forward'):
                self.original_forward = model.forward
                model.forward = self._create_forward_wrapper(self.original_forward)

    def batch_end(self, state: State, logger: Logger) -> None:
        """Called at the end of each batch."""
        # Clean up references
        self._current_state = None
        self._current_logger = None

    def fit_end(self, state: State, logger: Logger) -> None:
        """Restore the original forward method when training ends."""
        if self.original_forward is not None:
            state.model.forward = self.original_forward
            self.original_forward = None
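
As a quick sanity check, a small standalone sketch of what _inspect_tensor records for an integer tensor (the import path matches the file added above):

import torch

from llmfoundry.callbacks.batch_inspection_callback import BatchInspectionCallback

cb = BatchInspectionCallback(sample_size=3)
info = cb._inspect_tensor(torch.arange(12).reshape(3, 4), 'input_ids')
print(info['shape'])     # [3, 4]
print(info['first_10'])  # [0, 1, ..., 9]
print(info['last_10'])   # [2, 3, ..., 11]
print(info['min'], info['max'], info['mean'])  # 0 11 5.5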
27 changes: 27 additions & 0 deletions llmfoundry/callbacks/packing_efficiency_callback.py
@@ -0,0 +1,27 @@
# Copyright 2024 onwards Answer.AI, LightOn, and contributors
# License: Apache-2.0

from composer.core import Callback, State
from composer.loggers import Logger

__all__ = ["PackingEfficiency"]


class PackingEfficiency(Callback):
    """Records the packing efficiency for each batch."""

    def __init__(self, log_interval: int = 100):
        self.log_interval = log_interval

    def after_dataloader(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval != 0:
            return
        logger.log_metrics(
            {
                "trainer/packing_efficiency": self._packing_efficiency(state),
            },
        )

    def _packing_efficiency(self, state: State) -> float:
        # Fraction of token slots in the batch that hold real (non-padding) tokens.
        return state.batch["attention_mask"].sum().item() / state.batch["attention_mask"].numel()
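
To make the metric concrete, a toy computation of the same ratio outside the callback (illustrative values only):

import torch

# Two packed sequences in a (2, 4) batch: five real tokens out of eight slots.
attention_mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
])
print(attention_mask.sum().item() / attention_mask.numel())  # 0.625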
