From abaa96ec0096e095b3195f7c0f1b3ffed945b281 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Sun, 27 Apr 2025 21:55:36 +0800 Subject: [PATCH 1/5] feat(example): add fsdp_main.py --- examples/accelerate_configs/zero2.yaml | 21 ++ examples/fsdp_main.py | 197 +++++++++++ examples/fsdp_toy_train.py | 468 +++++++++++++++++++++++++ examples/toy_train.py | 45 ++- 4 files changed, 712 insertions(+), 19 deletions(-) create mode 100644 examples/accelerate_configs/zero2.yaml create mode 100644 examples/fsdp_main.py create mode 100644 examples/fsdp_toy_train.py diff --git a/examples/accelerate_configs/zero2.yaml b/examples/accelerate_configs/zero2.yaml new file mode 100644 index 0000000..f4df9d4 --- /dev/null +++ b/examples/accelerate_configs/zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/examples/fsdp_main.py b/examples/fsdp_main.py new file mode 100644 index 0000000..ec636d8 --- /dev/null +++ b/examples/fsdp_main.py @@ -0,0 +1,197 @@ +# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py +import os +import argparse +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms + + +from torch.optim.lr_scheduler import StepLR + +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + CPUOffload, + BackwardPrefetch, +) +from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + enable_wrap, + wrap, +) + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + +def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): + model.train() + ddp_loss = torch.zeros(2).to(rank) + if sampler: + sampler.set_epoch(epoch) + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(rank), target.to(rank) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target, reduction='sum') + loss.backward() + optimizer.step() + ddp_loss[0] += loss.item() + ddp_loss[1] 
+= len(data) + + dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) + if rank == 0: + print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1])) + +def test(model, rank, world_size, test_loader): + model.eval() + correct = 0 + ddp_loss = torch.zeros(3).to(rank) + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(rank), target.to(rank) + output = model(data) + ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item() + ddp_loss[2] += len(data) + + dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) + + if rank == 0: + test_loss = ddp_loss[0] / ddp_loss[2] + print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( + test_loss, int(ddp_loss[1]), int(ddp_loss[2]), + 100. * ddp_loss[1] / ddp_loss[2])) + +def fsdp_main(rank, world_size, args): + print((rank, world_size)) + setup(rank, world_size) + + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + + dataset1 = datasets.MNIST('../data', train=True, download=True, + transform=transform) + dataset2 = datasets.MNIST('../data', train=False, + transform=transform) + + sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True) + sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size) + + train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1} + test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2} + cuda_kwargs = {'num_workers': 2, + 'pin_memory': True, + 'shuffle': False} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + my_auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=100 + ) + torch.cuda.set_device(rank) + + + init_start_event = torch.cuda.Event(enable_timing=True) + init_end_event = torch.cuda.Event(enable_timing=True) + + model = Net().to(rank) + + model = DDP(model) + + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + init_start_event.record() + for epoch in range(1, args.epochs + 1): + train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1) + test(model, rank, world_size, test_loader) + scheduler.step() + + init_end_event.record() + + if rank == 0: + init_end_event.synchronize() + print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec") + print(f"{model}") + + if args.save_model: + # use a barrier to make sure training is done on all ranks + dist.barrier() + states = model.state_dict() + if rank == 0: + torch.save(states, "mnist_cnn.pt") + + cleanup() + +if __name__ == '__main__': + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning 
rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + + torch.manual_seed(args.seed) + + WORLD_SIZE = torch.cuda.device_count() + mp.spawn(fsdp_main, + args=(WORLD_SIZE, args), + nprocs=WORLD_SIZE, + join=True) \ No newline at end of file diff --git a/examples/fsdp_toy_train.py b/examples/fsdp_toy_train.py new file mode 100644 index 0000000..35c1417 --- /dev/null +++ b/examples/fsdp_toy_train.py @@ -0,0 +1,468 @@ +import os +import math +import torch +from loguru import logger +from datasets import load_dataset +import argparse +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler +from transformers import ( + Qwen2Config, + Qwen2ForCausalLM, + Qwen2Tokenizer, + get_cosine_schedule_with_warmup, +) +from tqdm import tqdm +from typing import Optional +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.tensor import distribute_tensor +from torch.distributed.tensor import DTensor +import torch.distributed as dist + +def to_dist(x, from_local=False, **meta): + if from_local: + return DTensor.from_local( + x, + device_mesh=meta["device_mesh"], + placements=meta["placements"], + shape=meta["shape"], + stride=meta["stride"], + ) + else: + return distribute_tensor(x, device_mesh=meta["device_mesh"], placements=meta["placements"]) + +def to_local(x, keep_sharded=False): + if isinstance(x, DTensor): + meta = dict( + device_mesh=x.device_mesh, + placements=x.placements, + shape=x.shape, + stride=x.stride(), + ) + if keep_sharded: + return x.to_local(), meta + else: + return x.full_tensor(), meta + + return x, None + +class Muon(torch.optim.Optimizer): + def __init__( + self, + muon_params=None, + lr=1e-3, + weight_decay=0.1, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_params=None, + betas=(0.9, 0.95), + eps=1e-8, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + capturable: bool = False, + differentiable: bool = False, + fused: Optional[bool] = None, + bias_correction=True, + ): + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + foreach=foreach, + maximize=maximize, + capturable=capturable, + differentiable=differentiable, + fused=fused, + bias_correction=bias_correction, + ) + + params = [] + + muon_params = list(muon_params) if muon_params is not None else [] + params.extend(muon_params) + + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + + super().__init__(params, defaults) + + # sort params into those for which we will use muon and those for which we will not + for p in muon_params: + # for p in group["params"]: + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + for p in adamw_params: + # for p in group["params"]: + self.state[p]["use_muon"] = False + + @staticmethod + def adjust_lr_for_muon(lr, param_shape): + A, B = param_shape[:2] + + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + + 
return adjusted_lr + + @staticmethod + def _update_adamw( + data, + grad, + exp_avg, + exp_avg_sq, + lr, + beta1, + beta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + ): + grad = grad.to(data.dtype) + + # Decay the first and second moment running average coefficient + exp_avg.lerp_(grad, 1 - beta1) + exp_avg_sq.lerp_(grad.square(), 1 - beta2) + + grad = exp_avg / (eps + exp_avg_sq.sqrt()) + + scale = bias_correction1 / bias_correction2**0.5 + + if weight_decay != 0: + data.mul_(1 - lr * weight_decay) + + data.add_(grad, alpha=-lr / scale) + + @torch.no_grad() + def step(self, closure=None, **kwargs): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params = [p for p in group["params"] if self.state[p]["use_muon"]] + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(group["momentum"]).add_(g) + g = g.add(buf, alpha=group["momentum"]) if group["nesterov"] else buf + + meta = None + if isinstance(g, DTensor): + g, meta = to_local(g, keep_sharded=False) + + # gives NaNs when done with DTensor, instead of throwing a typical op not supported error, quite sneaky + g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + if meta is not None: + g = to_dist(g, **meta) + + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + g = g.view_as(p.data).type_as(p.data) + + # apply weight decay + if group["weight_decay"] != 0: + p.data.mul_(1 - group["lr"] * group["weight_decay"]) + + # apply lr and update + adjusted_lr = self.adjust_lr_for_muon(group["lr"], p.shape) + p.data.add_(g, alpha=-adjusted_lr) + + # adamw + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + beta1, beta2 = group["betas"] + + for p in params: + g = p.grad + if g is None: + continue + + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + # gradient momentums + state["exp_avg"] = torch.zeros_like(p, device=p.device) + # gradient variances + state["exp_avg_sq"] = torch.zeros_like(p, device=p.device) + + state["step"] += 1 + + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + self._update_adamw( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + ) + return loss + +class MoonDataset(Dataset): + def __init__(self, dataset_name, dataset, tokenizer, max_length=512): + self.dataset_name = dataset_name + self.dataset = dataset + self.tokenizer = tokenizer + self.texts = dataset["train"]["text"] + self.max_length = max_length + self.tokens = [] + self._tokenize_texts() + + def _tokenize_texts(self): + if os.path.exists(f"{self.dataset_name}.bin"): + print('loading tokenized data') + self.tokens = torch.load(f"{self.dataset_name}.bin") + else: + for text in tqdm(self.texts, desc="Tokenizing texts"): + encoded = self.tokenizer.encode(text, add_special_tokens=True) + self.tokens.extend(encoded) + torch.save(self.tokens, f"{self.dataset_name}.bin") + + def __len__(self): + return len(self.tokens) // self.max_length + + def __getitem__(self, idx): + start_idx = idx * (self.max_length) + end_idx = start_idx + (self.max_length) + token_slice = self.tokens[start_idx:end_idx] + data = 
torch.tensor(token_slice, dtype=torch.long) + return data + + +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.compile +def zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = ( + b * A + c * A @ A + ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): + if optimizer_name == "adamw": + return torch.optim.AdamW( + model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) + ) + elif optimizer_name == "muon": + muon_params = [ + p + for name, p in model.named_parameters() + if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ] + adamw_params = [ + p + for name, p in model.named_parameters() + if not ( + p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ) + ] + + return Muon( + lr=lr, + muon_params=muon_params, + adamw_params=adamw_params, + ) + else: + assert 0, "optimizer not supported" + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="qwen") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--wd", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="openwebtext-100k") + parser.add_argument("--hidden_size", type=int, default=1024) + parser.add_argument("--optimizer", type=str, default="muon") + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + return parser.parse_args() + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): + if optimizer_name == "adamw": + return torch.optim.AdamW( + model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) + ) + elif optimizer_name == "muon": + muon_params = [ + p + for name, p in model.named_parameters() + if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ] + adamw_params = [ + p + for name, p in model.named_parameters() + if not ( + p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ) + ] + + return Muon( + lr=lr, + muon_params=muon_params, + adamw_params=adamw_params, + ) + else: + assert 0, "optimizer not supported" + +def get_train_loader(dataset_name, 
rank, world_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) + tokenizer = Qwen2Tokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", trust_remote_code=True + ) + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + sampler = DistributedSampler(dataset=train_dataset, rank=rank, num_replicas=world_size, shuffle=True) + kwargs = {'batch_size': 16, 'sampler': sampler} + train_loader = DataLoader(train_dataset, **kwargs) + return train_loader + +def fsdp_main(rank, world_size, args): + print((rank, world_size)) + setup(rank, world_size) + torch.cuda.set_device(rank) + # load model + config = Qwen2Config( + attention_dropout=0.0, + bos_token_id=151643, + eos_token_id=151643, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=513, + max_window_layers=12, + model_type="qwen2", + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + rms_norm_eps=1e-06, + rope_theta=1000000.0, + sliding_window=1024, + tie_word_embeddings=True, + torch_dtype="bfloat16", + use_cache=True, + use_mrope=False, + use_sliding_window=False, + vocab_size=151936, + ) + model = Qwen2ForCausalLM(config) + optimizer = get_optimizer(args.optimizer, model, lr=args.lr, wd=args.wd) + model = model.to(rank) + model = DDP(model) + model.train() + + train_loader = get_train_loader(args.dataset, rank, world_size) + epoch = 1 + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_loader) * epoch, + num_cycles=0.5, + ) + + ddp_loss = torch.zeros(2).to(rank) + for epoch in range(epoch): + + for step, batch in enumerate(train_loader): + batch = batch.to(rank) + input_ids = batch + output = model(input_ids=input_ids, labels=input_ids) + optimizer.zero_grad() + loss = output.loss + loss.backward() + optimizer.step() + ddp_loss[0] += loss.item() + ddp_loss[1] += len(batch) + + dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) + if rank == 0: + logger.info('Train Epoch: {} \tLoss: {:.6f} \t Batch: {}'.format(epoch, ddp_loss[0] / ddp_loss[1], len(batch))) + lr_scheduler.step() + + if args.save_model: + # use a barrier to make sure training is done on all ranks + dist.barrier() + states = model.state_dict() + if rank == 0: + torch.save(states, "mnist_cnn.pt") + + cleanup() + +if __name__ == '__main__': + # Training settings + args = parse_args() + + # fsdp_main(0, 1, args) + + WORLD_SIZE = torch.cuda.device_count() + mp.spawn(fsdp_main, + args=(WORLD_SIZE, args), + nprocs=WORLD_SIZE, + join=True) \ No newline at end of file diff --git a/examples/toy_train.py b/examples/toy_train.py index fa1e339..a883beb 100644 --- a/examples/toy_train.py +++ b/examples/toy_train.py @@ -131,6 +131,8 @@ def __init__( params.extend(adamw_params) super().__init__(params, defaults) # Sort parameters into those for which we will use Muon, and those for which we will not + import pdb + pdb.set_trace() for p in muon_params: # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer assert p.ndim == 2, p.ndim @@ -239,20 +241,7 @@ def step(self, closure=None): return loss -def get_model_and_dataloader(model_name, dataset_name, hidden_size): - name2path = { - "openwebtext-100k": "Elriggs/openwebtext-100k", - } - train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) - if model_name == "qwen": - tokenizer = 
Qwen2Tokenizer.from_pretrained( - "Qwen/Qwen2.5-0.5B", trust_remote_code=True - ) - else: - assert 0, f"model {model_name} not supported" - train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) - train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) - +def get_model(model_name, dataset_name, hidden_size): if model_name == "qwen": config = Qwen2Config( attention_dropout=0.0, @@ -281,7 +270,24 @@ def get_model_and_dataloader(model_name, dataset_name, hidden_size): model = Qwen2ForCausalLM(config) else: assert 0, f"model {model_name} not supported" - return model, train_loader + return model + + +def get_dataloader(model_name, dataset_name, hidden_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) + if model_name == "qwen": + tokenizer = Qwen2Tokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", trust_remote_code=True + ) + else: + assert 0, f"model {model_name} not supported" + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) + + return train_loader def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): @@ -318,7 +324,7 @@ def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="qwen") - parser.add_argument("--optimizer", type=str, default="adamw") + parser.add_argument("--optimizer", type=str, default="muon") parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--wd", type=float, default=0.1) parser.add_argument("--dataset", type=str, default="openwebtext-100k") @@ -326,9 +332,8 @@ def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): args = parser.parse_args() logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") - model, train_loader = get_model_and_dataloader( - args.model, args.dataset, args.hidden_size - ) + model = get_model(args.model, args.dataset, args.hidden_size) + optimizer = get_optimizer( args.optimizer, model, lr=args.lr ) @@ -338,6 +343,8 @@ def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): model.train() epoch = 1 + train_loader = get_dataloader(args.model, args.dataset, args.hidden_size) + lr_scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, From 5c1d98fe6b9d6bdc5d746407aabfdedb3598db85 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Mon, 28 Apr 2025 17:46:12 +0800 Subject: [PATCH 2/5] feat(examples): add toy fsdp train --- examples/dustbin.py | 179 ++++++++++++++++ examples/fsdp_main.py | 3 +- examples/fsdp_toy_train.py | 358 ++++++++++++++++---------------- examples/toy_train.py | 37 ++-- examples/toy_train_fsdp.py | 406 ++++++++++++++++++++++++++++++++++++ examples/toy_train_v1.py | 368 +++++++++++++++++++++++++++++++++ examples/toy_train_v2.py | 407 +++++++++++++++++++++++++++++++++++++ 7 files changed, 1557 insertions(+), 201 deletions(-) create mode 100644 examples/dustbin.py create mode 100644 examples/toy_train_fsdp.py create mode 100644 examples/toy_train_v1.py create mode 100644 examples/toy_train_v2.py diff --git a/examples/dustbin.py b/examples/dustbin.py new file mode 100644 index 0000000..61edc09 --- /dev/null +++ b/examples/dustbin.py @@ -0,0 +1,179 @@ + +# class Muon(torch.optim.Optimizer): +# def __init__( +# self, +# muon_params=None, +# lr=1e-3, +# weight_decay=0.1, +# momentum=0.95, +# nesterov=True, +# ns_steps=5, +# adamw_params=None, +# betas=(0.9, 
0.95), +# eps=1e-8, +# *, +# maximize: bool = False, +# foreach: Optional[bool] = None, +# capturable: bool = False, +# differentiable: bool = False, +# fused: Optional[bool] = None, +# bias_correction=True, +# ): +# defaults = dict( +# lr=lr, +# betas=betas, +# eps=eps, +# weight_decay=weight_decay, +# momentum=momentum, +# nesterov=nesterov, +# ns_steps=ns_steps, +# foreach=foreach, +# maximize=maximize, +# capturable=capturable, +# differentiable=differentiable, +# fused=fused, +# bias_correction=bias_correction, +# ) + +# params = [] + +# muon_params = list(muon_params) if muon_params is not None else [] +# params.extend(muon_params) + +# adamw_params = list(adamw_params) if adamw_params is not None else [] +# params.extend(adamw_params) + +# super().__init__(params, defaults) + +# # sort params into those for which we will use muon and those for which we will not +# for p in muon_params: +# # for p in group["params"]: +# assert p.ndim == 2, p.ndim +# self.state[p]["use_muon"] = True +# for p in adamw_params: +# # for p in group["params"]: +# self.state[p]["use_muon"] = False + +# @staticmethod +# def adjust_lr_for_muon(lr, param_shape): +# A, B = param_shape[:2] + +# adjusted_ratio = 0.2 * math.sqrt(max(A, B)) +# adjusted_lr = lr * adjusted_ratio + +# return adjusted_lr + +# @staticmethod +# def _update_adamw( +# data, +# grad, +# exp_avg, +# exp_avg_sq, +# lr, +# beta1, +# beta2, +# eps, +# weight_decay, +# bias_correction1, +# bias_correction2, +# ): +# grad = grad.to(data.dtype) + +# # Decay the first and second moment running average coefficient +# exp_avg.lerp_(grad, 1 - beta1) +# exp_avg_sq.lerp_(grad.square(), 1 - beta2) + +# grad = exp_avg / (eps + exp_avg_sq.sqrt()) + +# scale = bias_correction1 / bias_correction2**0.5 + +# if weight_decay != 0: +# data.mul_(1 - lr * weight_decay) + +# data.add_(grad, alpha=-lr / scale) + +# @torch.no_grad() +# def step(self, closure=None, **kwargs): +# loss = None +# if closure is not None: +# with torch.enable_grad(): +# loss = closure() + +# for group in self.param_groups: +# params = [p for p in group["params"] if self.state[p]["use_muon"]] + +# for p in params: +# g = p.grad +# if g is None: +# continue +# if g.ndim > 2: +# g = g.view(g.size(0), -1) +# assert g is not None + +# # calc update +# state = self.state[p] + +# if "momentum_buffer" not in state: +# state["momentum_buffer"] = torch.zeros_like(g) +# buf = state["momentum_buffer"] +# buf.mul_(group["momentum"]).add_(g) +# g = g.add(buf, alpha=group["momentum"]) if group["nesterov"] else buf + +# meta = None +# if isinstance(g, DTensor): +# g, meta = to_local(g, keep_sharded=False) + +# # gives NaNs when done with DTensor, instead of throwing a typical op not supported error, quite sneaky +# g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + +# if meta is not None: +# g = to_dist(g, **meta) + +# g *= max(1, g.size(0) / g.size(1)) ** 0.5 +# g = g.view_as(p.data).type_as(p.data) + +# # apply weight decay +# if group["weight_decay"] != 0: +# p.data.mul_(1 - group["lr"] * group["weight_decay"]) + +# # apply lr and update +# adjusted_lr = self.adjust_lr_for_muon(group["lr"], p.shape) +# p.data.add_(g, alpha=-adjusted_lr) + +# # adamw +# params = [p for p in group["params"] if not self.state[p]["use_muon"]] +# beta1, beta2 = group["betas"] + +# for p in params: +# g = p.grad +# if g is None: +# continue + +# state = self.state[p] + +# if "step" not in state: +# state["step"] = 0 +# # gradient momentums +# state["exp_avg"] = torch.zeros_like(p, device=p.device) +# # gradient 
variances +# state["exp_avg_sq"] = torch.zeros_like(p, device=p.device) + +# state["step"] += 1 + +# bias_correction1 = 1 - beta1 ** state["step"] +# bias_correction2 = 1 - beta2 ** state["step"] + +# self._update_adamw( +# p.data, +# p.grad.data, +# state["exp_avg"], +# state["exp_avg_sq"], +# group["lr"], +# beta1, +# beta2, +# group["eps"], +# group["weight_decay"], +# bias_correction1, +# bias_correction2, +# ) +# return loss \ No newline at end of file diff --git a/examples/fsdp_main.py b/examples/fsdp_main.py index ec636d8..61930ce 100644 --- a/examples/fsdp_main.py +++ b/examples/fsdp_main.py @@ -139,8 +139,7 @@ def fsdp_main(rank, world_size, args): init_end_event = torch.cuda.Event(enable_timing=True) model = Net().to(rank) - - model = DDP(model) + model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy) optimizer = optim.Adadelta(model.parameters(), lr=args.lr) diff --git a/examples/fsdp_toy_train.py b/examples/fsdp_toy_train.py index 35c1417..1889fd3 100644 --- a/examples/fsdp_toy_train.py +++ b/examples/fsdp_toy_train.py @@ -7,6 +7,7 @@ import torch.multiprocessing as mp import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset +import functools from torch.utils.data.distributed import DistributedSampler from transformers import ( Qwen2Config, @@ -19,7 +20,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed.tensor import distribute_tensor from torch.distributed.tensor import DTensor +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + import torch.distributed as dist +from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + enable_wrap, + wrap, +) + def to_dist(x, from_local=False, **meta): if from_local: @@ -31,7 +40,10 @@ def to_dist(x, from_local=False, **meta): stride=meta["stride"], ) else: - return distribute_tensor(x, device_mesh=meta["device_mesh"], placements=meta["placements"]) + return distribute_tensor(x, + device_mesh=meta["device_mesh"], + placements=meta["placements"]) + def to_local(x, keep_sharded=False): if isinstance(x, DTensor): @@ -48,110 +60,105 @@ def to_local(x, keep_sharded=False): return x, None + class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. 
+ adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + def __init__( - self, - muon_params=None, - lr=1e-3, - weight_decay=0.1, - momentum=0.95, - nesterov=True, - ns_steps=5, - adamw_params=None, - betas=(0.9, 0.95), - eps=1e-8, - *, - maximize: bool = False, - foreach: Optional[bool] = None, - capturable: bool = False, - differentiable: bool = False, - fused: Optional[bool] = None, - bias_correction=True, + self, + lr=1e-3, + wd=0.1, + muon_params=None, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_params=None, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, ): + defaults = dict( lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, + wd=wd, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, - foreach=foreach, - maximize=maximize, - capturable=capturable, - differentiable=differentiable, - fused=fused, - bias_correction=bias_correction, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, ) - params = [] - - muon_params = list(muon_params) if muon_params is not None else [] - params.extend(muon_params) - + params = list(muon_params) adamw_params = list(adamw_params) if adamw_params is not None else [] params.extend(adamw_params) - super().__init__(params, defaults) - - # sort params into those for which we will use muon and those for which we will not + # Sort parameters into those for which we will use Muon, and those for which we will not for p in muon_params: - # for p in group["params"]: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer assert p.ndim == 2, p.ndim self.state[p]["use_muon"] = True for p in adamw_params: - # for p in group["params"]: + # Do not use Muon for parameters in adamw_params self.state[p]["use_muon"] = False - @staticmethod - def adjust_lr_for_muon(lr, param_shape): + def adjust_lr_for_muon(self, lr, param_shape): A, B = param_shape[:2] - + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper adjusted_ratio = 0.2 * math.sqrt(max(A, B)) adjusted_lr = lr * adjusted_ratio - return adjusted_lr - @staticmethod - def _update_adamw( - data, - grad, - exp_avg, - exp_avg_sq, - lr, - beta1, - beta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - ): - grad = grad.to(data.dtype) - - # Decay the first and second moment running average coefficient - exp_avg.lerp_(grad, 1 - beta1) - exp_avg_sq.lerp_(grad.square(), 1 - beta2) - - grad = exp_avg / (eps + exp_avg_sq.sqrt()) + def step(self, closure=None): + """Perform a single optimization step. - scale = bias_correction1 / bias_correction2**0.5 - - if weight_decay != 0: - data.mul_(1 - lr * weight_decay) - - data.add_(grad, alpha=-lr / scale) - - @torch.no_grad() - def step(self, closure=None, **kwargs): + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: + + ############################ + # Muon # + ############################ + params = [p for p in group["params"] if self.state[p]["use_muon"]] + # import pdb; pdb.set_trace() + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + # generate weight updates in distributed fashion for p in params: + # sanity check g = p.grad if g is None: continue @@ -161,73 +168,66 @@ def step(self, closure=None, **kwargs): # calc update state = self.state[p] - if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(g) buf = state["momentum_buffer"] - buf.mul_(group["momentum"]).add_(g) - g = g.add(buf, alpha=group["momentum"]) if group["nesterov"] else buf - - meta = None - if isinstance(g, DTensor): - g, meta = to_local(g, keep_sharded=False) - - # gives NaNs when done with DTensor, instead of throwing a typical op not supported error, quite sneaky - g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - - if meta is not None: - g = to_dist(g, **meta) + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - g *= max(1, g.size(0) / g.size(1)) ** 0.5 - g = g.view_as(p.data).type_as(p.data) + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) # apply weight decay - if group["weight_decay"] != 0: - p.data.mul_(1 - group["lr"] * group["weight_decay"]) + p.data.mul_(1 - lr * wd) - # apply lr and update - adjusted_lr = self.adjust_lr_for_muon(group["lr"], p.shape) - p.data.add_(g, alpha=-adjusted_lr) + # apply update + p.data.add_(u, alpha=-adjusted_lr) - # adamw - params = [p for p in group["params"] if not self.state[p]["use_muon"]] - beta1, beta2 = group["betas"] + ############################ + # AdamW backup # + ############################ + + params = [ + p for p in group["params"] if not self.state[p]["use_muon"] + ] + lr = group['lr'] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] for p in params: g = p.grad if g is None: continue - state = self.state[p] - if "step" not in state: state["step"] = 0 - # gradient momentums - state["exp_avg"] = torch.zeros_like(p, device=p.device) - # gradient variances - state["exp_avg_sq"] = torch.zeros_like(p, device=p.device) - + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - self._update_adamw( - p.data, - p.grad.data, - state["exp_avg"], - state["exp_avg_sq"], - group["lr"], - beta1, - beta2, - group["eps"], - group["weight_decay"], - bias_correction1, - bias_correction2, - ) return loss + class MoonDataset(Dataset): + def __init__(self, dataset_name, dataset, tokenizer, max_length=512): self.dataset_name = dataset_name self.dataset = dataset @@ -281,9 +281,8 @@ def zeropower_via_newtonschulz5(G, steps): # Perform the NS iterations for _ in range(steps): A = X @ X.T - B = ( - b * A + c * A @ A - ) # adapted from 
suggestion by @jxbz, @leloykun, and @YouJiacheng + B = (b * A + c * A @ A + ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng X = a * X + B @ X if G.size(0) > G.size(1): @@ -291,36 +290,8 @@ def zeropower_via_newtonschulz5(G, steps): return X -def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): - if optimizer_name == "adamw": - return torch.optim.AdamW( - model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) - ) - elif optimizer_name == "muon": - muon_params = [ - p - for name, p in model.named_parameters() - if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ] - adamw_params = [ - p - for name, p in model.named_parameters() - if not ( - p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ) - ] - - return Muon( - lr=lr, - muon_params=muon_params, - adamw_params=adamw_params, - ) - else: - assert 0, "optimizer not supported" - def parse_args(): import argparse - parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="qwen") parser.add_argument("--lr", type=float, default=1e-3) @@ -328,10 +299,13 @@ def parse_args(): parser.add_argument("--dataset", type=str, default="openwebtext-100k") parser.add_argument("--hidden_size", type=int, default=1024) parser.add_argument("--optimizer", type=str, default="muon") - parser.add_argument('--save-model', action='store_true', default=False, + parser.add_argument('--save-model', + action='store_true', + default=False, help='For Saving the current Model') return parser.parse_args() + def setup(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' @@ -339,53 +313,66 @@ def setup(rank, world_size): # initialize the process group dist.init_process_group("nccl", rank=rank, world_size=world_size) + def cleanup(): dist.destroy_process_group() + def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): if optimizer_name == "adamw": - return torch.optim.AdamW( - model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) - ) + return torch.optim.AdamW(model.parameters(), + lr=lr, + weight_decay=wd, + betas=(0.9, 0.95)) elif optimizer_name == "muon": muon_params = [ - p - for name, p in model.named_parameters() - if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + p for name, p in model.named_parameters() if p.ndim >= 2 + and "embed_tokens" not in name and "lm_head" not in name ] adamw_params = [ - p - for name, p in model.named_parameters() - if not ( - p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ) + p for name, p in model.named_parameters() + if not (p.ndim >= 2 and "embed_tokens" not in name + and "lm_head" not in name) ] return Muon( lr=lr, + wd=wd, muon_params=muon_params, adamw_params=adamw_params, ) else: assert 0, "optimizer not supported" + def get_train_loader(dataset_name, rank, world_size): name2path = { "openwebtext-100k": "Elriggs/openwebtext-100k", } - train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) - tokenizer = Qwen2Tokenizer.from_pretrained( - "Qwen/Qwen2.5-0.5B", trust_remote_code=True - ) + train_dataset = load_dataset(name2path[dataset_name], + trust_remote_code=True) + tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", + trust_remote_code=True) train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) - sampler = DistributedSampler(dataset=train_dataset, rank=rank, num_replicas=world_size, shuffle=True) - kwargs = {'batch_size': 16, 'sampler': sampler} + sampler = 
DistributedSampler(dataset=train_dataset, + rank=rank, + num_replicas=world_size, + shuffle=True) + kwargs = { + 'batch_size': 16, + 'sampler': sampler, + 'num_workers': 4, + 'pin_memory': True + } train_loader = DataLoader(train_dataset, **kwargs) return train_loader + def fsdp_main(rank, world_size, args): print((rank, world_size)) setup(rank, world_size) + my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, + min_num_params=100) torch.cuda.set_device(rank) # load model config = Qwen2Config( @@ -415,7 +402,7 @@ def fsdp_main(rank, world_size, args): model = Qwen2ForCausalLM(config) optimizer = get_optimizer(args.optimizer, model, lr=args.lr, wd=args.wd) model = model.to(rank) - model = DDP(model) + model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy) model.train() train_loader = get_train_loader(args.dataset, rank, world_size) @@ -427,25 +414,31 @@ def fsdp_main(rank, world_size, args): num_cycles=0.5, ) - ddp_loss = torch.zeros(2).to(rank) for epoch in range(epoch): - + total_loss = torch.zeros(1, device=rank) for step, batch in enumerate(train_loader): batch = batch.to(rank) input_ids = batch output = model(input_ids=input_ids, labels=input_ids) - optimizer.zero_grad() loss = output.loss loss.backward() + + # Synchronize the loss across all processes + dist.all_reduce(loss, op=dist.ReduceOp.SUM) + avg_loss = loss.item() / world_size # 计算平均损失 + optimizer.step() - ddp_loss[0] += loss.item() - ddp_loss[1] += len(batch) + lr_scheduler.step() + optimizer.zero_grad() - dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) + # Log the average loss only on the main process (rank 0) if rank == 0: - logger.info('Train Epoch: {} \tLoss: {:.6f} \t Batch: {}'.format(epoch, ddp_loss[0] / ddp_loss[1], len(batch))) - lr_scheduler.step() - + logger.info( + f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {avg_loss}" + ) + # Update total_loss for logging purposes + total_loss += avg_loss + if args.save_model: # use a barrier to make sure training is done on all ranks dist.barrier() @@ -455,14 +448,15 @@ def fsdp_main(rank, world_size, args): cleanup() + if __name__ == '__main__': # Training settings args = parse_args() - # fsdp_main(0, 1, args) + fsdp_main(0, 1, args) - WORLD_SIZE = torch.cuda.device_count() - mp.spawn(fsdp_main, - args=(WORLD_SIZE, args), - nprocs=WORLD_SIZE, - join=True) \ No newline at end of file + # WORLD_SIZE = torch.cuda.device_count() + # mp.spawn(fsdp_main, + # args=(WORLD_SIZE, args), + # nprocs=WORLD_SIZE, + # join=True) diff --git a/examples/toy_train.py b/examples/toy_train.py index a883beb..59c7281 100644 --- a/examples/toy_train.py +++ b/examples/toy_train.py @@ -131,8 +131,6 @@ def __init__( params.extend(adamw_params) super().__init__(params, defaults) # Sort parameters into those for which we will use Muon, and those for which we will not - import pdb - pdb.set_trace() for p in muon_params: # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer assert p.ndim == 2, p.ndim @@ -318,20 +316,7 @@ def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): else: assert 0, "optimizer not supported" - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="qwen") - parser.add_argument("--optimizer", type=str, default="muon") - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--wd", type=float, default=0.1) - parser.add_argument("--dataset", type=str, 
default="openwebtext-100k") - parser.add_argument("--hidden_size", type=int, default=1024) - args = parser.parse_args() - logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") - +def main(args): model = get_model(args.model, args.dataset, args.hidden_size) optimizer = get_optimizer( @@ -344,7 +329,9 @@ def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): model.train() epoch = 1 train_loader = get_dataloader(args.model, args.dataset, args.hidden_size) - + # 13299 + print('train data length:', len(train_loader)) + lr_scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, @@ -364,3 +351,19 @@ def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): logger.info( f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="qwen") + parser.add_argument("--optimizer", type=str, default="muon") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--wd", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="openwebtext-100k") + parser.add_argument("--hidden_size", type=int, default=1024) + args = parser.parse_args() + logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") + main(args=args) + \ No newline at end of file diff --git a/examples/toy_train_fsdp.py b/examples/toy_train_fsdp.py new file mode 100644 index 0000000..bd8012e --- /dev/null +++ b/examples/toy_train_fsdp.py @@ -0,0 +1,406 @@ +from loguru import logger +from datasets import load_dataset +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler +import os +import math +import torch +import argparse +import torch.multiprocessing as mp +import functools +from transformers import ( + Qwen2Config, + Qwen2ForCausalLM, + Qwen2Tokenizer, + get_cosine_schedule_with_warmup, +) +from tqdm import tqdm +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +import torch.distributed as dist +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy + + +class MoonDataset(Dataset): + + def __init__(self, dataset_name, dataset, tokenizer, max_length=512): + self.dataset_name = dataset_name + self.dataset = dataset + self.tokenizer = tokenizer + self.texts = dataset["train"]["text"] + self.max_length = max_length + self.tokens = [] + self._tokenize_texts() + + def _tokenize_texts(self): + if os.path.exists(f"{self.dataset_name}.bin"): + self.tokens = torch.load(f"{self.dataset_name}.bin") + else: + for text in tqdm(self.texts, desc="Tokenizing texts"): + encoded = self.tokenizer.encode(text, add_special_tokens=True) + self.tokens.extend(encoded) + torch.save(self.tokens, f"{self.dataset_name}.bin") + + def __len__(self): + return len(self.tokens) // self.max_length + + def __getitem__(self, idx): + start_idx = idx * (self.max_length) + end_idx = start_idx + (self.max_length) + token_slice = self.tokens[start_idx:end_idx] + data = torch.tensor(token_slice, dtype=torch.long) + return data + + +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.compile +def zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. 
We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = (b * A + c * A @ A + ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. 
+ """ + + def __init__( + self, + lr=1e-3, + wd=0.1, + muon_params=None, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_params=None, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + ): + + defaults = dict( + lr=lr, + wd=wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + super().__init__(params, defaults) + # Sort parameters into those for which we will use Muon, and those for which we will not + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + # import pdb; pdb.set_trace() + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + # generate weight updates in distributed fashion + for p in params: + # sanity check + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + ############################ + # AdamW backup # + ############################ + + params = [ + p for p in group["params"] if not self.state[p]["use_muon"] + ] + lr = group['lr'] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss + + +def get_model_cpu(model_name, hidden_size): + if model_name == "qwen": + config = Qwen2Config( + attention_dropout=0.0, + bos_token_id=151643, + eos_token_id=151643, + 
hidden_act="silu", + hidden_size=hidden_size, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=513, + max_window_layers=12, + model_type="qwen2", + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + rms_norm_eps=1e-06, + rope_theta=1000000.0, + sliding_window=1024, + tie_word_embeddings=True, + torch_dtype="bfloat16", + use_cache=True, + use_mrope=False, + use_sliding_window=False, + vocab_size=151936, + ) + model = Qwen2ForCausalLM(config) + else: + assert 0, f"model {model_name} not supported" + return model + + +def get_dataloader(model_name, dataset_name, rank, world_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name], + trust_remote_code=True) + if model_name == "qwen": + tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", + trust_remote_code=True) + else: + assert 0, f"model {model_name} not supported" + + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + sampler = DistributedSampler(dataset=train_dataset, + rank=rank, + num_replicas=world_size, + shuffle=True) + kwargs = { + 'batch_size': 16, + 'sampler': sampler, + 'num_workers': 4, + 'pin_memory': True + } + train_loader = DataLoader(train_dataset, **kwargs) + return train_loader + + +def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): + if optimizer_name == "adamw": + return torch.optim.AdamW(model.parameters(), + lr=lr, + weight_decay=wd, + betas=(0.9, 0.95)) + elif optimizer_name == "muon": + muon_params = [ + p for name, p in model.named_parameters() if p.ndim >= 2 + and "embed_tokens" not in name and "lm_head" not in name + ] + adamw_params = [ + p for name, p in model.named_parameters() + if not (p.ndim >= 2 and "embed_tokens" not in name + and "lm_head" not in name) + ] + + return Muon( + lr=lr, + wd=wd, + muon_params=muon_params, + adamw_params=adamw_params, + ) + else: + assert 0, "optimizer not supported" + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12365' + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def fsdp_main(rank, world_size, args): + + print((rank, world_size)) + setup(rank, world_size) + model = get_model_cpu(args.model, args.hidden_size) + my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, + min_num_params=100) + torch.cuda.set_device(rank) + model.to(rank) + model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy) + model.train() + optimizer = get_optimizer(args.optimizer, model, lr=args.lr) + + epoch = 1 + train_loader = get_dataloader(args.model, args.dataset, rank, world_size) + logger.info(('train data length:', len(train_loader))) + + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_loader) * epoch, + num_cycles=0.5, + ) + for epoch in range(epoch): + for step, batch in enumerate(train_loader): + batch = batch.to(rank) + input_ids = batch + + outputs = model(input_ids=input_ids, labels=input_ids) + loss = outputs.loss + loss.backward() + + dist.all_reduce(loss, op=dist.ReduceOp.AVG) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + if rank == 0: + logger.info( + f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" + ) + cleanup() + + +if __name__ == "__main__": + import argparse + + parser = 
argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="qwen") + parser.add_argument("--optimizer", type=str, default="muon") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--wd", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="openwebtext-100k") + parser.add_argument("--hidden_size", type=int, default=1024) + args = parser.parse_args() + WORLD_SIZE = torch.cuda.device_count() + mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) + # 2025-04-28 17:28:51.906 | INFO | __mp_main__:fsdp_main:387 - Epoch: 0 Step: 0 LR: 1e-05 Training loss: 12.13530158996582 diff --git a/examples/toy_train_v1.py b/examples/toy_train_v1.py new file mode 100644 index 0000000..6658da9 --- /dev/null +++ b/examples/toy_train_v1.py @@ -0,0 +1,368 @@ +import os +import math +import torch +from loguru import logger +from datasets import load_dataset +from torch.utils.data import DataLoader, Dataset +from transformers import ( + Qwen2Config, + Qwen2ForCausalLM, + Qwen2Tokenizer, + get_cosine_schedule_with_warmup, +) +from tqdm import tqdm +from torch.utils.data.distributed import DistributedSampler + +class MoonDataset(Dataset): + def __init__(self, dataset_name, dataset, tokenizer, max_length=512): + self.dataset_name = dataset_name + self.dataset = dataset + self.tokenizer = tokenizer + self.texts = dataset["train"]["text"] + self.max_length = max_length + self.tokens = [] + self._tokenize_texts() + + def _tokenize_texts(self): + if os.path.exists(f"{self.dataset_name}.bin"): + self.tokens = torch.load(f"{self.dataset_name}.bin") + else: + for text in tqdm(self.texts, desc="Tokenizing texts"): + encoded = self.tokenizer.encode(text, add_special_tokens=True) + self.tokens.extend(encoded) + torch.save(self.tokens, f"{self.dataset_name}.bin") + + def __len__(self): + return len(self.tokens) // self.max_length + + def __getitem__(self, idx): + start_idx = idx * (self.max_length) + end_idx = start_idx + (self.max_length) + token_slice = self.tokens[start_idx:end_idx] + data = torch.tensor(token_slice, dtype=torch.long) + return data + + +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.compile +def zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
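A minimal sanity check of the property described above (illustrative sketch only, not part of this example; it assumes the function defined just below and a PyTorch 2.x build, since the function is wrapped in @torch.compile):

import torch

G = torch.randn(128, 256)
X = zeropower_via_newtonschulz5(G, steps=5).float()
# The iteration pushes G's singular values into roughly [0.5, 1.5] rather than
# exactly to 1, so X is only approximately orthogonal.
print(torch.linalg.svdvals(G)[:4])  # spread-out input spectrum
print(torch.linalg.svdvals(X)[:4])  # compressed towards 1 after 5 steps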
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = ( + b * A + c * A @ A + ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + lr=1e-3, + wd=0.1, + muon_params=None, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_params=None, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + ): + + defaults = dict( + lr=lr, + wd=wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + super().__init__(params, defaults) + # Sort parameters into those for which we will use Muon, and those for which we will not + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
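To make the scaling in adjust_lr_for_muon above concrete, a small worked example (the shape and learning rate are illustrative, chosen to match the hidden_size / intermediate_size used in this example's Qwen2 config):

import math

# For a (1024, 4864) MLP projection weight trained with lr = 1e-3:
#   adjusted_lr = lr * 0.2 * sqrt(max(1024, 4864))
#               = 1e-3 * 0.2 * 69.74...
#               ≈ 1.4e-2
adjusted_lr = 1e-3 * 0.2 * math.sqrt(max(1024, 4864))
# In step() below, weight decay is applied with the unscaled lr
# (p.data.mul_(1 - lr * wd)), while the orthogonalized update uses adjusted_lr.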
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + # import pdb; pdb.set_trace() + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + # generate weight updates in distributed fashion + for p in params: + # sanity check + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group['lr'] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss + + +def get_model(model_name, dataset_name, hidden_size): + if model_name == "qwen": + config = Qwen2Config( + attention_dropout=0.0, + bos_token_id=151643, + eos_token_id=151643, + hidden_act="silu", + hidden_size=hidden_size, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=513, + max_window_layers=12, + model_type="qwen2", + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + rms_norm_eps=1e-06, + rope_theta=1000000.0, + sliding_window=1024, + tie_word_embeddings=True, + torch_dtype="bfloat16", + use_cache=True, + use_mrope=False, + use_sliding_window=False, + vocab_size=151936, + ) + model = Qwen2ForCausalLM(config) + else: + assert 0, f"model {model_name} not supported" + return model + + +def get_dataloader(model_name, dataset_name, hidden_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) + if model_name == "qwen": + tokenizer = Qwen2Tokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", trust_remote_code=True + ) + else: + assert 0, f"model {model_name} not supported" + + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + sampler = DistributedSampler(dataset=train_dataset, rank=0, num_replicas=1, shuffle=True) + kwargs = {'batch_size': 16, 'sampler': sampler, 'num_workers': 4, 'pin_memory': True} + train_loader = DataLoader(train_dataset, **kwargs) + return train_loader + + +def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): + if optimizer_name == "adamw": + return 
torch.optim.AdamW( + model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) + ) + elif optimizer_name == "muon": + muon_params = [ + p + for name, p in model.named_parameters() + if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ] + adamw_params = [ + p + for name, p in model.named_parameters() + if not ( + p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ) + ] + + return Muon( + lr=lr, + wd=wd, + muon_params=muon_params, + adamw_params=adamw_params, + ) + else: + assert 0, "optimizer not supported" + +def main(args): + model = get_model(args.model, args.dataset, args.hidden_size) + + optimizer = get_optimizer( + args.optimizer, model, lr=args.lr + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + + model.train() + epoch = 1 + train_loader = get_dataloader(args.model, args.dataset, args.hidden_size) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_loader) * epoch, + num_cycles=0.5, + ) + for epoch in range(epoch): + for step, batch in enumerate(train_loader): + batch = batch.to(device) + input_ids = batch + outputs = model(input_ids=input_ids, labels=input_ids) + loss = outputs.loss + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + logger.info( + f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="qwen") + parser.add_argument("--optimizer", type=str, default="muon") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--wd", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="openwebtext-100k") + parser.add_argument("--hidden_size", type=int, default=1024) + args = parser.parse_args() + logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") + main(args=args) + \ No newline at end of file diff --git a/examples/toy_train_v2.py b/examples/toy_train_v2.py new file mode 100644 index 0000000..db6a152 --- /dev/null +++ b/examples/toy_train_v2.py @@ -0,0 +1,407 @@ +import os +import math +import torch +from loguru import logger +from datasets import load_dataset +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +from torch.utils.data.distributed import DistributedSampler +import os +import math +import torch +import argparse +import torch.multiprocessing as mp +import torch.nn.functional as F +import functools +from transformers import ( + Qwen2Config, + Qwen2ForCausalLM, + Qwen2Tokenizer, + get_cosine_schedule_with_warmup, +) +from tqdm import tqdm +from typing import Optional +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.tensor import distribute_tensor +from torch.distributed.tensor import DTensor +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +import torch.distributed as dist +from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + enable_wrap, + wrap, +) + +class MoonDataset(Dataset): + def __init__(self, dataset_name, dataset, tokenizer, max_length=512): + self.dataset_name = dataset_name + self.dataset = dataset + self.tokenizer = tokenizer + self.texts = dataset["train"]["text"] + self.max_length = max_length + self.tokens = [] + self._tokenize_texts() + + def _tokenize_texts(self): + if 
os.path.exists(f"{self.dataset_name}.bin"): + self.tokens = torch.load(f"{self.dataset_name}.bin") + else: + for text in tqdm(self.texts, desc="Tokenizing texts"): + encoded = self.tokenizer.encode(text, add_special_tokens=True) + self.tokens.extend(encoded) + torch.save(self.tokens, f"{self.dataset_name}.bin") + + def __len__(self): + return len(self.tokens) // self.max_length + + def __getitem__(self, idx): + start_idx = idx * (self.max_length) + end_idx = start_idx + (self.max_length) + token_slice = self.tokens[start_idx:end_idx] + data = torch.tensor(token_slice, dtype=torch.long) + return data + + +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.compile +def zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = ( + b * A + c * A @ A + ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. 
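A minimal construction sketch consistent with the argument list above (hypothetical `model` variable; it mirrors the parameter split performed in get_optimizer() later in this file):

# Hidden 2-D weights go to Muon; biases, norms, embeddings and lm_head go to
# the internal AdamW branch.
muon_params = [p for n, p in model.named_parameters()
               if p.ndim == 2 and "embed_tokens" not in n and "lm_head" not in n]
adamw_params = [p for n, p in model.named_parameters()
                if not (p.ndim == 2 and "embed_tokens" not in n and "lm_head" not in n)]
optimizer = Muon(lr=1e-3, wd=0.1, muon_params=muon_params, adamw_params=adamw_params)
# The constructor asserts p.ndim == 2 for every entry of muon_params, so the
# filter must exclude 1-D tensors; get_optimizer() uses p.ndim >= 2, which is
# equivalent here because Qwen2 has no parameters with more than two dimensions.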
+ """ + + def __init__( + self, + lr=1e-3, + wd=0.1, + muon_params=None, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_params=None, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + ): + + defaults = dict( + lr=lr, + wd=wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + super().__init__(params, defaults) + # Sort parameters into those for which we will use Muon, and those for which we will not + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + # import pdb; pdb.set_trace() + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + # generate weight updates in distributed fashion + for p in params: + # sanity check + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group['lr'] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss + + +def get_model(model_name, dataset_name, hidden_size): + if model_name == "qwen": + config = Qwen2Config( + attention_dropout=0.0, + bos_token_id=151643, + eos_token_id=151643, 
+ hidden_act="silu", + hidden_size=hidden_size, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=513, + max_window_layers=12, + model_type="qwen2", + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + rms_norm_eps=1e-06, + rope_theta=1000000.0, + sliding_window=1024, + tie_word_embeddings=True, + torch_dtype="bfloat16", + use_cache=True, + use_mrope=False, + use_sliding_window=False, + vocab_size=151936, + ) + model = Qwen2ForCausalLM(config) + else: + assert 0, f"model {model_name} not supported" + return model + + +def get_dataloader(model_name, dataset_name, rank, world_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) + if model_name == "qwen": + tokenizer = Qwen2Tokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", trust_remote_code=True + ) + else: + assert 0, f"model {model_name} not supported" + + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + sampler = DistributedSampler(dataset=train_dataset, rank=rank, num_replicas=world_size, shuffle=True) + kwargs = {'batch_size': 16, 'sampler': sampler, 'num_workers': 4, 'pin_memory': True} + train_loader = DataLoader(train_dataset, **kwargs) + return train_loader + + +def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): + if optimizer_name == "adamw": + return torch.optim.AdamW( + model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) + ) + elif optimizer_name == "muon": + muon_params = [ + p + for name, p in model.named_parameters() + if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ] + adamw_params = [ + p + for name, p in model.named_parameters() + if not ( + p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ) + ] + + return Muon( + lr=lr, + wd=wd, + muon_params=muon_params, + adamw_params=adamw_params, + ) + else: + assert 0, "optimizer not supported" + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +def fsdp_main(rank, world_size, args): + print((rank, world_size)) + setup(rank, world_size) + model = get_model(args.model, args.dataset, args.hidden_size) + model = model.to(rank) + my_auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=100 + ) + model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy) + model.train() + + optimizer = get_optimizer( + args.optimizer, model, lr=args.lr + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + + model.train() + epoch = 1 + train_loader = get_dataloader(args.model, args.dataset, rank, world_size) + + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_loader) * epoch, + num_cycles=0.5, + ) + for epoch in range(epoch): + for step, batch in enumerate(train_loader): + batch = batch.to(device) + input_ids = batch + outputs = model(input_ids=input_ids, labels=input_ids) + loss = outputs.loss + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + if rank == 0: + logger.info( + f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + 
parser.add_argument("--model", type=str, default="qwen") + parser.add_argument("--optimizer", type=str, default="muon") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--wd", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="openwebtext-100k") + parser.add_argument("--hidden_size", type=int, default=1024) + args = parser.parse_args() + logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") + fsdp_main(rank=0, world_size=1, args=args) From 051ccc5e22db03af95e694c5b7749094d8c48f68 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Mon, 28 Apr 2025 17:49:35 +0800 Subject: [PATCH 3/5] feat(examples): update --- examples/accelerate_configs/zero2.yaml | 21 -- examples/fsdp_main.py | 196 ----------- examples/fsdp_toy_train.py | 462 ------------------------- examples/toy_train_v1.py | 368 -------------------- examples/toy_train_v2.py | 407 ---------------------- 5 files changed, 1454 deletions(-) delete mode 100644 examples/accelerate_configs/zero2.yaml delete mode 100644 examples/fsdp_main.py delete mode 100644 examples/fsdp_toy_train.py delete mode 100644 examples/toy_train_v1.py delete mode 100644 examples/toy_train_v2.py diff --git a/examples/accelerate_configs/zero2.yaml b/examples/accelerate_configs/zero2.yaml deleted file mode 100644 index f4df9d4..0000000 --- a/examples/accelerate_configs/zero2.yaml +++ /dev/null @@ -1,21 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - deepspeed_multinode_launcher: standard - offload_optimizer_device: none - offload_param_device: none - zero3_init_flag: false - zero_stage: 2 -distributed_type: DEEPSPEED -downcast_bf16: 'no' -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 8 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false \ No newline at end of file diff --git a/examples/fsdp_main.py b/examples/fsdp_main.py deleted file mode 100644 index 61930ce..0000000 --- a/examples/fsdp_main.py +++ /dev/null @@ -1,196 +0,0 @@ -# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py -import os -import argparse -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms - - -from torch.optim.lr_scheduler import StepLR - -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data.distributed import DistributedSampler -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - CPUOffload, - BackwardPrefetch, -) -from torch.distributed.fsdp.wrap import ( - size_based_auto_wrap_policy, - enable_wrap, - wrap, -) - -def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - -def cleanup(): - dist.destroy_process_group() - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 32, 3, 1) - self.conv2 = nn.Conv2d(32, 64, 3, 1) - self.dropout1 = nn.Dropout(0.25) - self.dropout2 = nn.Dropout(0.5) - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 10) - - def forward(self, x): - - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - x = 
F.max_pool2d(x, 2) - x = self.dropout1(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output - -def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): - model.train() - ddp_loss = torch.zeros(2).to(rank) - if sampler: - sampler.set_epoch(epoch) - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(rank), target.to(rank) - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target, reduction='sum') - loss.backward() - optimizer.step() - ddp_loss[0] += loss.item() - ddp_loss[1] += len(data) - - dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) - if rank == 0: - print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1])) - -def test(model, rank, world_size, test_loader): - model.eval() - correct = 0 - ddp_loss = torch.zeros(3).to(rank) - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(rank), target.to(rank) - output = model(data) - ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item() - ddp_loss[2] += len(data) - - dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) - - if rank == 0: - test_loss = ddp_loss[0] / ddp_loss[2] - print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( - test_loss, int(ddp_loss[1]), int(ddp_loss[2]), - 100. * ddp_loss[1] / ddp_loss[2])) - -def fsdp_main(rank, world_size, args): - print((rank, world_size)) - setup(rank, world_size) - - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - - dataset1 = datasets.MNIST('../data', train=True, download=True, - transform=transform) - dataset2 = datasets.MNIST('../data', train=False, - transform=transform) - - sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True) - sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size) - - train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1} - test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2} - cuda_kwargs = {'num_workers': 2, - 'pin_memory': True, - 'shuffle': False} - train_kwargs.update(cuda_kwargs) - test_kwargs.update(cuda_kwargs) - - train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) - test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) - my_auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=100 - ) - torch.cuda.set_device(rank) - - - init_start_event = torch.cuda.Event(enable_timing=True) - init_end_event = torch.cuda.Event(enable_timing=True) - - model = Net().to(rank) - model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy) - - optimizer = optim.Adadelta(model.parameters(), lr=args.lr) - - scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) - init_start_event.record() - for epoch in range(1, args.epochs + 1): - train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1) - test(model, rank, world_size, test_loader) - scheduler.step() - - init_end_event.record() - - if rank == 0: - init_end_event.synchronize() - print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec") - print(f"{model}") - - if args.save_model: - # use a barrier to make sure training is done on all ranks - 
dist.barrier() - states = model.state_dict() - if rank == 0: - torch.save(states, "mnist_cnn.pt") - - cleanup() - -if __name__ == '__main__': - # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 14)') - parser.add_argument('--lr', type=float, default=1.0, metavar='LR', - help='learning rate (default: 1.0)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', - help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--save-model', action='store_true', default=False, - help='For Saving the current Model') - args = parser.parse_args() - - torch.manual_seed(args.seed) - - WORLD_SIZE = torch.cuda.device_count() - mp.spawn(fsdp_main, - args=(WORLD_SIZE, args), - nprocs=WORLD_SIZE, - join=True) \ No newline at end of file diff --git a/examples/fsdp_toy_train.py b/examples/fsdp_toy_train.py deleted file mode 100644 index 1889fd3..0000000 --- a/examples/fsdp_toy_train.py +++ /dev/null @@ -1,462 +0,0 @@ -import os -import math -import torch -from loguru import logger -from datasets import load_dataset -import argparse -import torch.multiprocessing as mp -import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset -import functools -from torch.utils.data.distributed import DistributedSampler -from transformers import ( - Qwen2Config, - Qwen2ForCausalLM, - Qwen2Tokenizer, - get_cosine_schedule_with_warmup, -) -from tqdm import tqdm -from typing import Optional -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.distributed.tensor import distribute_tensor -from torch.distributed.tensor import DTensor -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -import torch.distributed as dist -from torch.distributed.fsdp.wrap import ( - size_based_auto_wrap_policy, - enable_wrap, - wrap, -) - - -def to_dist(x, from_local=False, **meta): - if from_local: - return DTensor.from_local( - x, - device_mesh=meta["device_mesh"], - placements=meta["placements"], - shape=meta["shape"], - stride=meta["stride"], - ) - else: - return distribute_tensor(x, - device_mesh=meta["device_mesh"], - placements=meta["placements"]) - - -def to_local(x, keep_sharded=False): - if isinstance(x, DTensor): - meta = dict( - device_mesh=x.device_mesh, - placements=x.placements, - shape=x.shape, - stride=x.stride(), - ) - if keep_sharded: - return x.to_local(), meta - else: - return x.full_tensor(), meta - - return x, None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. 
- - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - muon_params: The parameters to be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - adamw_wd: The weight decay for the internal AdamW. - """ - - def __init__( - self, - lr=1e-3, - wd=0.1, - muon_params=None, - momentum=0.95, - nesterov=True, - ns_steps=5, - adamw_params=None, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - ): - - defaults = dict( - lr=lr, - wd=wd, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - ) - - params = list(muon_params) - adamw_params = list(adamw_params) if adamw_params is not None else [] - params.extend(adamw_params) - super().__init__(params, defaults) - # Sort parameters into those for which we will use Muon, and those for which we will not - for p in muon_params: - # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer - assert p.ndim == 2, p.ndim - self.state[p]["use_muon"] = True - for p in adamw_params: - # Do not use Muon for parameters in adamw_params - self.state[p]["use_muon"] = False - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def step(self, closure=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. 
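For reference, the AdamW fallback branch of step() below implements the following update (a restatement of the code in standard Adam notation):

# m_t = beta1 * m_{t-1} + (1 - beta1) * g_t       # buf1.lerp_(g, 1 - beta1)
# v_t = beta2 * v_{t-1} + (1 - beta2) * g_t**2    # buf2.lerp_(g.square(), 1 - beta2)
# u_t = m_t / (eps + sqrt(v_t))
# scale = (1 - beta1**t) / sqrt(1 - beta2**t)     # bias-correction ratio
# p  <-  p * (1 - lr * wd) - (lr / scale) * u_t
# i.e. bias correction is folded into the step size rather than into m_t and v_t.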
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - - ############################ - # Muon # - ############################ - - params = [p for p in group["params"] if self.state[p]["use_muon"]] - # import pdb; pdb.set_trace() - lr = group["lr"] - wd = group["wd"] - momentum = group["momentum"] - - # generate weight updates in distributed fashion - for p in params: - # sanity check - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - - # scale update - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - - # apply weight decay - p.data.mul_(1 - lr * wd) - - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - ############################ - # AdamW backup # - ############################ - - params = [ - p for p in group["params"] if not self.state[p]["use_muon"] - ] - lr = group['lr'] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["wd"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - if "step" not in state: - state["step"] = 0 - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - state["step"] += 1 - step = state["step"] - buf1 = state["moment1"] - buf2 = state["moment2"] - buf1.lerp_(g, 1 - beta1) - buf2.lerp_(g.square(), 1 - beta2) - - g = buf1 / (eps + buf2.sqrt()) - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - scale = bias_correction1 / bias_correction2**0.5 - p.data.mul_(1 - lr * weight_decay) - p.data.add_(g, alpha=-lr / scale) - - return loss - - -class MoonDataset(Dataset): - - def __init__(self, dataset_name, dataset, tokenizer, max_length=512): - self.dataset_name = dataset_name - self.dataset = dataset - self.tokenizer = tokenizer - self.texts = dataset["train"]["text"] - self.max_length = max_length - self.tokens = [] - self._tokenize_texts() - - def _tokenize_texts(self): - if os.path.exists(f"{self.dataset_name}.bin"): - print('loading tokenized data') - self.tokens = torch.load(f"{self.dataset_name}.bin") - else: - for text in tqdm(self.texts, desc="Tokenizing texts"): - encoded = self.tokenizer.encode(text, add_special_tokens=True) - self.tokens.extend(encoded) - torch.save(self.tokens, f"{self.dataset_name}.bin") - - def __len__(self): - return len(self.tokens) // self.max_length - - def __getitem__(self, idx): - start_idx = idx * (self.max_length) - end_idx = start_idx + (self.max_length) - token_slice = self.tokens[start_idx:end_idx] - data = torch.tensor(token_slice, dtype=torch.long) - return data - - -# This code snippet is a modified version adapted from the following GitHub repository: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -@torch.compile -def zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. 
For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - # Perform the NS iterations - for _ in range(steps): - A = X @ X.T - B = (b * A + c * A @ A - ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X - - if G.size(0) > G.size(1): - X = X.T - return X - - -def parse_args(): - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="qwen") - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--wd", type=float, default=0.1) - parser.add_argument("--dataset", type=str, default="openwebtext-100k") - parser.add_argument("--hidden_size", type=int, default=1024) - parser.add_argument("--optimizer", type=str, default="muon") - parser.add_argument('--save-model', - action='store_true', - default=False, - help='For Saving the current Model') - return parser.parse_args() - - -def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - - -def cleanup(): - dist.destroy_process_group() - - -def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): - if optimizer_name == "adamw": - return torch.optim.AdamW(model.parameters(), - lr=lr, - weight_decay=wd, - betas=(0.9, 0.95)) - elif optimizer_name == "muon": - muon_params = [ - p for name, p in model.named_parameters() if p.ndim >= 2 - and "embed_tokens" not in name and "lm_head" not in name - ] - adamw_params = [ - p for name, p in model.named_parameters() - if not (p.ndim >= 2 and "embed_tokens" not in name - and "lm_head" not in name) - ] - - return Muon( - lr=lr, - wd=wd, - muon_params=muon_params, - adamw_params=adamw_params, - ) - else: - assert 0, "optimizer not supported" - - -def get_train_loader(dataset_name, rank, world_size): - name2path = { - "openwebtext-100k": "Elriggs/openwebtext-100k", - } - train_dataset = load_dataset(name2path[dataset_name], - trust_remote_code=True) - tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", - trust_remote_code=True) - train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) - sampler = DistributedSampler(dataset=train_dataset, - rank=rank, - num_replicas=world_size, - shuffle=True) - kwargs = { - 'batch_size': 16, - 'sampler': sampler, - 'num_workers': 4, - 'pin_memory': True - } - train_loader = DataLoader(train_dataset, **kwargs) - return train_loader - - -def fsdp_main(rank, world_size, args): - print((rank, world_size)) - setup(rank, world_size) - my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, - min_num_params=100) - torch.cuda.set_device(rank) - # load model - config = Qwen2Config( - attention_dropout=0.0, - bos_token_id=151643, - eos_token_id=151643, - hidden_act="silu", - hidden_size=1024, - initializer_range=0.02, - intermediate_size=4864, - max_position_embeddings=513, - max_window_layers=12, - 
model_type="qwen2", - num_attention_heads=16, - num_hidden_layers=12, - num_key_value_heads=16, - rms_norm_eps=1e-06, - rope_theta=1000000.0, - sliding_window=1024, - tie_word_embeddings=True, - torch_dtype="bfloat16", - use_cache=True, - use_mrope=False, - use_sliding_window=False, - vocab_size=151936, - ) - model = Qwen2ForCausalLM(config) - optimizer = get_optimizer(args.optimizer, model, lr=args.lr, wd=args.wd) - model = model.to(rank) - model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy) - model.train() - - train_loader = get_train_loader(args.dataset, rank, world_size) - epoch = 1 - lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=100, - num_training_steps=len(train_loader) * epoch, - num_cycles=0.5, - ) - - for epoch in range(epoch): - total_loss = torch.zeros(1, device=rank) - for step, batch in enumerate(train_loader): - batch = batch.to(rank) - input_ids = batch - output = model(input_ids=input_ids, labels=input_ids) - loss = output.loss - loss.backward() - - # Synchronize the loss across all processes - dist.all_reduce(loss, op=dist.ReduceOp.SUM) - avg_loss = loss.item() / world_size # 计算平均损失 - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Log the average loss only on the main process (rank 0) - if rank == 0: - logger.info( - f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {avg_loss}" - ) - # Update total_loss for logging purposes - total_loss += avg_loss - - if args.save_model: - # use a barrier to make sure training is done on all ranks - dist.barrier() - states = model.state_dict() - if rank == 0: - torch.save(states, "mnist_cnn.pt") - - cleanup() - - -if __name__ == '__main__': - # Training settings - args = parse_args() - - fsdp_main(0, 1, args) - - # WORLD_SIZE = torch.cuda.device_count() - # mp.spawn(fsdp_main, - # args=(WORLD_SIZE, args), - # nprocs=WORLD_SIZE, - # join=True) diff --git a/examples/toy_train_v1.py b/examples/toy_train_v1.py deleted file mode 100644 index 6658da9..0000000 --- a/examples/toy_train_v1.py +++ /dev/null @@ -1,368 +0,0 @@ -import os -import math -import torch -from loguru import logger -from datasets import load_dataset -from torch.utils.data import DataLoader, Dataset -from transformers import ( - Qwen2Config, - Qwen2ForCausalLM, - Qwen2Tokenizer, - get_cosine_schedule_with_warmup, -) -from tqdm import tqdm -from torch.utils.data.distributed import DistributedSampler - -class MoonDataset(Dataset): - def __init__(self, dataset_name, dataset, tokenizer, max_length=512): - self.dataset_name = dataset_name - self.dataset = dataset - self.tokenizer = tokenizer - self.texts = dataset["train"]["text"] - self.max_length = max_length - self.tokens = [] - self._tokenize_texts() - - def _tokenize_texts(self): - if os.path.exists(f"{self.dataset_name}.bin"): - self.tokens = torch.load(f"{self.dataset_name}.bin") - else: - for text in tqdm(self.texts, desc="Tokenizing texts"): - encoded = self.tokenizer.encode(text, add_special_tokens=True) - self.tokens.extend(encoded) - torch.save(self.tokens, f"{self.dataset_name}.bin") - - def __len__(self): - return len(self.tokens) // self.max_length - - def __getitem__(self, idx): - start_idx = idx * (self.max_length) - end_idx = start_idx + (self.max_length) - token_slice = self.tokens[start_idx:end_idx] - data = torch.tensor(token_slice, dtype=torch.long) - return data - - -# This code snippet is a modified version adapted from the following GitHub repository: -# 
https://github.com/KellerJordan/Muon/blob/master/muon.py -@torch.compile -def zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - # Perform the NS iterations - for _ in range(steps): - A = X @ X.T - B = ( - b * A + c * A @ A - ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X - - if G.size(0) > G.size(1): - X = X.T - return X - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - muon_params: The parameters to be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - adamw_wd: The weight decay for the internal AdamW. 
- """ - - def __init__( - self, - lr=1e-3, - wd=0.1, - muon_params=None, - momentum=0.95, - nesterov=True, - ns_steps=5, - adamw_params=None, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - ): - - defaults = dict( - lr=lr, - wd=wd, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - ) - - params = list(muon_params) - adamw_params = list(adamw_params) if adamw_params is not None else [] - params.extend(adamw_params) - super().__init__(params, defaults) - # Sort parameters into those for which we will use Muon, and those for which we will not - for p in muon_params: - # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer - assert p.ndim == 2, p.ndim - self.state[p]["use_muon"] = True - for p in adamw_params: - # Do not use Muon for parameters in adamw_params - self.state[p]["use_muon"] = False - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def step(self, closure=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - - ############################ - # Muon # - ############################ - - params = [p for p in group["params"] if self.state[p]["use_muon"]] - # import pdb; pdb.set_trace() - lr = group["lr"] - wd = group["wd"] - momentum = group["momentum"] - - # generate weight updates in distributed fashion - for p in params: - # sanity check - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - - # scale update - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - - # apply weight decay - p.data.mul_(1 - lr * wd) - - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - ############################ - # AdamW backup # - ############################ - - params = [p for p in group["params"] if not self.state[p]["use_muon"]] - lr = group['lr'] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["wd"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - if "step" not in state: - state["step"] = 0 - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - state["step"] += 1 - step = state["step"] - buf1 = state["moment1"] - buf2 = state["moment2"] - buf1.lerp_(g, 1 - beta1) - buf2.lerp_(g.square(), 1 - beta2) - - g = buf1 / (eps + buf2.sqrt()) - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - scale = bias_correction1 / bias_correction2**0.5 - p.data.mul_(1 - lr * weight_decay) - p.data.add_(g, alpha=-lr / scale) - - return loss - - -def get_model(model_name, dataset_name, hidden_size): - if model_name == "qwen": - config = Qwen2Config( - attention_dropout=0.0, - bos_token_id=151643, - eos_token_id=151643, 
- hidden_act="silu", - hidden_size=hidden_size, - initializer_range=0.02, - intermediate_size=4864, - max_position_embeddings=513, - max_window_layers=12, - model_type="qwen2", - num_attention_heads=16, - num_hidden_layers=12, - num_key_value_heads=16, - rms_norm_eps=1e-06, - rope_theta=1000000.0, - sliding_window=1024, - tie_word_embeddings=True, - torch_dtype="bfloat16", - use_cache=True, - use_mrope=False, - use_sliding_window=False, - vocab_size=151936, - ) - model = Qwen2ForCausalLM(config) - else: - assert 0, f"model {model_name} not supported" - return model - - -def get_dataloader(model_name, dataset_name, hidden_size): - name2path = { - "openwebtext-100k": "Elriggs/openwebtext-100k", - } - train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) - if model_name == "qwen": - tokenizer = Qwen2Tokenizer.from_pretrained( - "Qwen/Qwen2.5-0.5B", trust_remote_code=True - ) - else: - assert 0, f"model {model_name} not supported" - - train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) - sampler = DistributedSampler(dataset=train_dataset, rank=0, num_replicas=1, shuffle=True) - kwargs = {'batch_size': 16, 'sampler': sampler, 'num_workers': 4, 'pin_memory': True} - train_loader = DataLoader(train_dataset, **kwargs) - return train_loader - - -def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): - if optimizer_name == "adamw": - return torch.optim.AdamW( - model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) - ) - elif optimizer_name == "muon": - muon_params = [ - p - for name, p in model.named_parameters() - if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ] - adamw_params = [ - p - for name, p in model.named_parameters() - if not ( - p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ) - ] - - return Muon( - lr=lr, - wd=wd, - muon_params=muon_params, - adamw_params=adamw_params, - ) - else: - assert 0, "optimizer not supported" - -def main(args): - model = get_model(args.model, args.dataset, args.hidden_size) - - optimizer = get_optimizer( - args.optimizer, model, lr=args.lr - ) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - model.train() - epoch = 1 - train_loader = get_dataloader(args.model, args.dataset, args.hidden_size) - lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=100, - num_training_steps=len(train_loader) * epoch, - num_cycles=0.5, - ) - for epoch in range(epoch): - for step, batch in enumerate(train_loader): - batch = batch.to(device) - input_ids = batch - outputs = model(input_ids=input_ids, labels=input_ids) - loss = outputs.loss - loss.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - logger.info( - f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" - ) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="qwen") - parser.add_argument("--optimizer", type=str, default="muon") - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--wd", type=float, default=0.1) - parser.add_argument("--dataset", type=str, default="openwebtext-100k") - parser.add_argument("--hidden_size", type=int, default=1024) - args = parser.parse_args() - logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") - main(args=args) - \ No newline at end of file diff --git a/examples/toy_train_v2.py b/examples/toy_train_v2.py 
deleted file mode 100644 index db6a152..0000000 --- a/examples/toy_train_v2.py +++ /dev/null @@ -1,407 +0,0 @@ -import os -import math -import torch -from loguru import logger -from datasets import load_dataset -from torch.utils.data import DataLoader, Dataset -from tqdm import tqdm -from torch.utils.data.distributed import DistributedSampler -import os -import math -import torch -import argparse -import torch.multiprocessing as mp -import torch.nn.functional as F -import functools -from transformers import ( - Qwen2Config, - Qwen2ForCausalLM, - Qwen2Tokenizer, - get_cosine_schedule_with_warmup, -) -from tqdm import tqdm -from typing import Optional -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.distributed.tensor import distribute_tensor -from torch.distributed.tensor import DTensor -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -import torch.distributed as dist -from torch.distributed.fsdp.wrap import ( - size_based_auto_wrap_policy, - enable_wrap, - wrap, -) - -class MoonDataset(Dataset): - def __init__(self, dataset_name, dataset, tokenizer, max_length=512): - self.dataset_name = dataset_name - self.dataset = dataset - self.tokenizer = tokenizer - self.texts = dataset["train"]["text"] - self.max_length = max_length - self.tokens = [] - self._tokenize_texts() - - def _tokenize_texts(self): - if os.path.exists(f"{self.dataset_name}.bin"): - self.tokens = torch.load(f"{self.dataset_name}.bin") - else: - for text in tqdm(self.texts, desc="Tokenizing texts"): - encoded = self.tokenizer.encode(text, add_special_tokens=True) - self.tokens.extend(encoded) - torch.save(self.tokens, f"{self.dataset_name}.bin") - - def __len__(self): - return len(self.tokens) // self.max_length - - def __getitem__(self, idx): - start_idx = idx * (self.max_length) - end_idx = start_idx + (self.max_length) - token_slice = self.tokens[start_idx:end_idx] - data = torch.tensor(token_slice, dtype=torch.long) - return data - - -# This code snippet is a modified version adapted from the following GitHub repository: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -@torch.compile -def zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
- """ - assert len(G.shape) == 2 - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - # Perform the NS iterations - for _ in range(steps): - A = X @ X.T - B = ( - b * A + c * A @ A - ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X - - if G.size(0) > G.size(1): - X = X.T - return X - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - muon_params: The parameters to be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - adamw_wd: The weight decay for the internal AdamW. - """ - - def __init__( - self, - lr=1e-3, - wd=0.1, - muon_params=None, - momentum=0.95, - nesterov=True, - ns_steps=5, - adamw_params=None, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - ): - - defaults = dict( - lr=lr, - wd=wd, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - ) - - params = list(muon_params) - adamw_params = list(adamw_params) if adamw_params is not None else [] - params.extend(adamw_params) - super().__init__(params, defaults) - # Sort parameters into those for which we will use Muon, and those for which we will not - for p in muon_params: - # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer - assert p.ndim == 2, p.ndim - self.state[p]["use_muon"] = True - for p in adamw_params: - # Do not use Muon for parameters in adamw_params - self.state[p]["use_muon"] = False - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def step(self, closure=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - - ############################ - # Muon # - ############################ - - params = [p for p in group["params"] if self.state[p]["use_muon"]] - # import pdb; pdb.set_trace() - lr = group["lr"] - wd = group["wd"] - momentum = group["momentum"] - - # generate weight updates in distributed fashion - for p in params: - # sanity check - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - - # scale update - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - - # apply weight decay - p.data.mul_(1 - lr * wd) - - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - ############################ - # AdamW backup # - ############################ - - params = [p for p in group["params"] if not self.state[p]["use_muon"]] - lr = group['lr'] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["wd"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - if "step" not in state: - state["step"] = 0 - state["moment1"] = torch.zeros_like(g) - state["moment2"] = torch.zeros_like(g) - state["step"] += 1 - step = state["step"] - buf1 = state["moment1"] - buf2 = state["moment2"] - buf1.lerp_(g, 1 - beta1) - buf2.lerp_(g.square(), 1 - beta2) - - g = buf1 / (eps + buf2.sqrt()) - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - scale = bias_correction1 / bias_correction2**0.5 - p.data.mul_(1 - lr * weight_decay) - p.data.add_(g, alpha=-lr / scale) - - return loss - - -def get_model(model_name, dataset_name, hidden_size): - if model_name == "qwen": - config = Qwen2Config( - attention_dropout=0.0, - bos_token_id=151643, - eos_token_id=151643, - hidden_act="silu", - hidden_size=hidden_size, - initializer_range=0.02, - intermediate_size=4864, - max_position_embeddings=513, - max_window_layers=12, - model_type="qwen2", - num_attention_heads=16, - num_hidden_layers=12, - num_key_value_heads=16, - rms_norm_eps=1e-06, - rope_theta=1000000.0, - sliding_window=1024, - tie_word_embeddings=True, - torch_dtype="bfloat16", - use_cache=True, - use_mrope=False, - use_sliding_window=False, - vocab_size=151936, - ) - model = Qwen2ForCausalLM(config) - else: - assert 0, f"model {model_name} not supported" - return model - - -def get_dataloader(model_name, dataset_name, rank, world_size): - name2path = { - "openwebtext-100k": "Elriggs/openwebtext-100k", - } - train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) - if model_name == "qwen": - tokenizer = Qwen2Tokenizer.from_pretrained( - "Qwen/Qwen2.5-0.5B", trust_remote_code=True - ) - else: - assert 0, f"model {model_name} not supported" - - train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) - sampler = DistributedSampler(dataset=train_dataset, rank=rank, num_replicas=world_size, shuffle=True) - kwargs = {'batch_size': 16, 'sampler': sampler, 'num_workers': 4, 'pin_memory': True} - train_loader = DataLoader(train_dataset, **kwargs) - return train_loader - - -def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): - if optimizer_name == 
"adamw": - return torch.optim.AdamW( - model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95) - ) - elif optimizer_name == "muon": - muon_params = [ - p - for name, p in model.named_parameters() - if p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ] - adamw_params = [ - p - for name, p in model.named_parameters() - if not ( - p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name - ) - ] - - return Muon( - lr=lr, - wd=wd, - muon_params=muon_params, - adamw_params=adamw_params, - ) - else: - assert 0, "optimizer not supported" - -def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - -def cleanup(): - dist.destroy_process_group() - -def fsdp_main(rank, world_size, args): - print((rank, world_size)) - setup(rank, world_size) - model = get_model(args.model, args.dataset, args.hidden_size) - model = model.to(rank) - my_auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=100 - ) - model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy) - model.train() - - optimizer = get_optimizer( - args.optimizer, model, lr=args.lr - ) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - model.train() - epoch = 1 - train_loader = get_dataloader(args.model, args.dataset, rank, world_size) - - lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=100, - num_training_steps=len(train_loader) * epoch, - num_cycles=0.5, - ) - for epoch in range(epoch): - for step, batch in enumerate(train_loader): - batch = batch.to(device) - input_ids = batch - outputs = model(input_ids=input_ids, labels=input_ids) - loss = outputs.loss - loss.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - if rank == 0: - logger.info( - f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" - ) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="qwen") - parser.add_argument("--optimizer", type=str, default="muon") - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--wd", type=float, default=0.1) - parser.add_argument("--dataset", type=str, default="openwebtext-100k") - parser.add_argument("--hidden_size", type=int, default=1024) - args = parser.parse_args() - logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") - fsdp_main(rank=0, world_size=1, args=args) From 07930a6e44817b1a46a4c5bc1a69bc31a00db262 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Mon, 28 Apr 2025 17:53:54 +0800 Subject: [PATCH 4/5] update(examples): update --- examples/dustbin.py | 179 ------------------------------------- examples/toy_train.py | 76 +++++++--------- examples/toy_train_fsdp.py | 3 - 3 files changed, 33 insertions(+), 225 deletions(-) delete mode 100644 examples/dustbin.py diff --git a/examples/dustbin.py b/examples/dustbin.py deleted file mode 100644 index 61edc09..0000000 --- a/examples/dustbin.py +++ /dev/null @@ -1,179 +0,0 @@ - -# class Muon(torch.optim.Optimizer): -# def __init__( -# self, -# muon_params=None, -# lr=1e-3, -# weight_decay=0.1, -# momentum=0.95, -# nesterov=True, -# ns_steps=5, -# adamw_params=None, -# betas=(0.9, 0.95), -# eps=1e-8, -# *, -# maximize: bool = False, -# foreach: Optional[bool] = None, -# capturable: bool = 
False, -# differentiable: bool = False, -# fused: Optional[bool] = None, -# bias_correction=True, -# ): -# defaults = dict( -# lr=lr, -# betas=betas, -# eps=eps, -# weight_decay=weight_decay, -# momentum=momentum, -# nesterov=nesterov, -# ns_steps=ns_steps, -# foreach=foreach, -# maximize=maximize, -# capturable=capturable, -# differentiable=differentiable, -# fused=fused, -# bias_correction=bias_correction, -# ) - -# params = [] - -# muon_params = list(muon_params) if muon_params is not None else [] -# params.extend(muon_params) - -# adamw_params = list(adamw_params) if adamw_params is not None else [] -# params.extend(adamw_params) - -# super().__init__(params, defaults) - -# # sort params into those for which we will use muon and those for which we will not -# for p in muon_params: -# # for p in group["params"]: -# assert p.ndim == 2, p.ndim -# self.state[p]["use_muon"] = True -# for p in adamw_params: -# # for p in group["params"]: -# self.state[p]["use_muon"] = False - -# @staticmethod -# def adjust_lr_for_muon(lr, param_shape): -# A, B = param_shape[:2] - -# adjusted_ratio = 0.2 * math.sqrt(max(A, B)) -# adjusted_lr = lr * adjusted_ratio - -# return adjusted_lr - -# @staticmethod -# def _update_adamw( -# data, -# grad, -# exp_avg, -# exp_avg_sq, -# lr, -# beta1, -# beta2, -# eps, -# weight_decay, -# bias_correction1, -# bias_correction2, -# ): -# grad = grad.to(data.dtype) - -# # Decay the first and second moment running average coefficient -# exp_avg.lerp_(grad, 1 - beta1) -# exp_avg_sq.lerp_(grad.square(), 1 - beta2) - -# grad = exp_avg / (eps + exp_avg_sq.sqrt()) - -# scale = bias_correction1 / bias_correction2**0.5 - -# if weight_decay != 0: -# data.mul_(1 - lr * weight_decay) - -# data.add_(grad, alpha=-lr / scale) - -# @torch.no_grad() -# def step(self, closure=None, **kwargs): -# loss = None -# if closure is not None: -# with torch.enable_grad(): -# loss = closure() - -# for group in self.param_groups: -# params = [p for p in group["params"] if self.state[p]["use_muon"]] - -# for p in params: -# g = p.grad -# if g is None: -# continue -# if g.ndim > 2: -# g = g.view(g.size(0), -1) -# assert g is not None - -# # calc update -# state = self.state[p] - -# if "momentum_buffer" not in state: -# state["momentum_buffer"] = torch.zeros_like(g) -# buf = state["momentum_buffer"] -# buf.mul_(group["momentum"]).add_(g) -# g = g.add(buf, alpha=group["momentum"]) if group["nesterov"] else buf - -# meta = None -# if isinstance(g, DTensor): -# g, meta = to_local(g, keep_sharded=False) - -# # gives NaNs when done with DTensor, instead of throwing a typical op not supported error, quite sneaky -# g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) - -# if meta is not None: -# g = to_dist(g, **meta) - -# g *= max(1, g.size(0) / g.size(1)) ** 0.5 -# g = g.view_as(p.data).type_as(p.data) - -# # apply weight decay -# if group["weight_decay"] != 0: -# p.data.mul_(1 - group["lr"] * group["weight_decay"]) - -# # apply lr and update -# adjusted_lr = self.adjust_lr_for_muon(group["lr"], p.shape) -# p.data.add_(g, alpha=-adjusted_lr) - -# # adamw -# params = [p for p in group["params"] if not self.state[p]["use_muon"]] -# beta1, beta2 = group["betas"] - -# for p in params: -# g = p.grad -# if g is None: -# continue - -# state = self.state[p] - -# if "step" not in state: -# state["step"] = 0 -# # gradient momentums -# state["exp_avg"] = torch.zeros_like(p, device=p.device) -# # gradient variances -# state["exp_avg_sq"] = torch.zeros_like(p, device=p.device) - -# state["step"] += 1 - -# 
bias_correction1 = 1 - beta1 ** state["step"] -# bias_correction2 = 1 - beta2 ** state["step"] - -# self._update_adamw( -# p.data, -# p.grad.data, -# state["exp_avg"], -# state["exp_avg_sq"], -# group["lr"], -# beta1, -# beta2, -# group["eps"], -# group["weight_decay"], -# bias_correction1, -# bias_correction2, -# ) -# return loss \ No newline at end of file diff --git a/examples/toy_train.py b/examples/toy_train.py index 59c7281..9b9ab35 100644 --- a/examples/toy_train.py +++ b/examples/toy_train.py @@ -239,7 +239,20 @@ def step(self, closure=None): return loss -def get_model(model_name, dataset_name, hidden_size): +def get_model_and_dataloader(model_name, dataset_name, hidden_size): + name2path = { + "openwebtext-100k": "Elriggs/openwebtext-100k", + } + train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) + if model_name == "qwen": + tokenizer = Qwen2Tokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", trust_remote_code=True + ) + else: + assert 0, f"model {model_name} not supported" + train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) + train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) + if model_name == "qwen": config = Qwen2Config( attention_dropout=0.0, @@ -268,24 +281,7 @@ def get_model(model_name, dataset_name, hidden_size): model = Qwen2ForCausalLM(config) else: assert 0, f"model {model_name} not supported" - return model - - -def get_dataloader(model_name, dataset_name, hidden_size): - name2path = { - "openwebtext-100k": "Elriggs/openwebtext-100k", - } - train_dataset = load_dataset(name2path[dataset_name], trust_remote_code=True) - if model_name == "qwen": - tokenizer = Qwen2Tokenizer.from_pretrained( - "Qwen/Qwen2.5-0.5B", trust_remote_code=True - ) - else: - assert 0, f"model {model_name} not supported" - train_dataset = MoonDataset(dataset_name, train_dataset, tokenizer) - train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) - - return train_loader + return model, train_loader def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): @@ -316,9 +312,23 @@ def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): else: assert 0, "optimizer not supported" -def main(args): - model = get_model(args.model, args.dataset, args.hidden_size) - + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="qwen") + parser.add_argument("--optimizer", type=str, default="adamw") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--wd", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="openwebtext-100k") + parser.add_argument("--hidden_size", type=int, default=1024) + args = parser.parse_args() + logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") + + model, train_loader = get_model_and_dataloader( + args.model, args.dataset, args.hidden_size + ) optimizer = get_optimizer( args.optimizer, model, lr=args.lr ) @@ -328,10 +338,6 @@ def main(args): model.train() epoch = 1 - train_loader = get_dataloader(args.model, args.dataset, args.hidden_size) - # 13299 - print('train data length:', len(train_loader)) - lr_scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, @@ -350,20 +356,4 @@ def main(args): optimizer.zero_grad() logger.info( f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" - ) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - 
parser.add_argument("--model", type=str, default="qwen") - parser.add_argument("--optimizer", type=str, default="muon") - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--wd", type=float, default=0.1) - parser.add_argument("--dataset", type=str, default="openwebtext-100k") - parser.add_argument("--hidden_size", type=int, default=1024) - args = parser.parse_args() - logger.add(f"logs/train_{args.model}_{args.optimizer}_lr{args.lr}.log") - main(args=args) - \ No newline at end of file + ) \ No newline at end of file diff --git a/examples/toy_train_fsdp.py b/examples/toy_train_fsdp.py index bd8012e..5e273d7 100644 --- a/examples/toy_train_fsdp.py +++ b/examples/toy_train_fsdp.py @@ -391,8 +391,6 @@ def fsdp_main(rank, world_size, args): if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="qwen") parser.add_argument("--optimizer", type=str, default="muon") @@ -403,4 +401,3 @@ def fsdp_main(rank, world_size, args): args = parser.parse_args() WORLD_SIZE = torch.cuda.device_count() mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) - # 2025-04-28 17:28:51.906 | INFO | __mp_main__:fsdp_main:387 - Epoch: 0 Step: 0 LR: 1e-05 Training loss: 12.13530158996582 From cf66dc337e969b4e64027029c778943e84e86f7e Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Mon, 28 Apr 2025 17:55:38 +0800 Subject: [PATCH 5/5] typo(examples): update --- examples/toy_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/toy_train.py b/examples/toy_train.py index 9b9ab35..fa1e339 100644 --- a/examples/toy_train.py +++ b/examples/toy_train.py @@ -356,4 +356,4 @@ def get_optimizer(optimizer_name, model, lr=1e-3, wd=0.1): optimizer.zero_grad() logger.info( f"Epoch: {epoch} Step: {step} LR: {optimizer.param_groups[0]['lr']} Training loss: {loss.item()}" - ) \ No newline at end of file + )