diff --git a/configs/llama/munin-7b-core-pt-dp.toml b/configs/llama/munin-7b-core-pt-dp.toml
new file mode 100644
index 0000000..6836279
--- /dev/null
+++ b/configs/llama/munin-7b-core-pt-dp.toml
@@ -0,0 +1,49 @@
+model_name = "llama3"
+flavor = "Comma7B"
+tokenizer_name = "common-pile/comma-v0.1-2t"
+
+# job
+job_name = "munin-7b-core-pt-dp"
+wandb_project = "munin-7b-core-pt-dp"
+enable_wandb = false
+
+# parallelism
+num_nodes = 1
+data_parallel_shard_degree = 8
+data_parallel_replicate_degree = 1
+
+# training settings
+train_batch_size = 4
+seq_len = 4096
+gradient_accumulation_steps = 2
+train_num_steps = 60097
+scheduler = "linear_warmup_constant_sqrt_decay"
+warmup_steps = 1000
+cooldown_steps = 1000
+checkpoint_interval = 500
+forced_load_path = "/work/training/maester/comma-v0.1-2t-dcp/"
+compile = true
+enable_cut_cross_entropy = false
+ac_mode = "none"
+selective_ac_option = "op"
+
+dp_enabled = true
+dp_clip_norm = 1.0
+dp_noise_multiplier = 1.0
+dp_num_privacy_units = 10000
+dp_delta = 1e-6
+
+[dataset]
+bos_token = 2
+eos_token = 1
+data_dirs = [
+    "/work/data/dfm-common-pile-16_3/",
+]
+dataset_weights = "1.0"
+
+[opt_cfg] # must specify *all* fields here, will not merge with defaults
+lr = 1e-5
+betas = [0.9, 0.95]
+weight_decay = 0.1
+eps = 1e-9
+fused = true
diff --git a/maester/config.py b/maester/config.py
index 5524b80..0c3f8e0 100644
--- a/maester/config.py
+++ b/maester/config.py
@@ -157,6 +157,14 @@ class Config(BaseSettings):
     enable_async_tensor_parallel: bool = False
     enable_compiled_autograd: bool = True
 
+    # differential privacy
+    dp_enabled: bool = False
+    dp_clip_norm: float = 1.0
+    dp_noise_multiplier: float = 1.0
+    dp_num_privacy_units: int = 1
+    dp_delta: float = 1e-6
+    dp_assert: bool = False
+
     # profiling
     enable_profiling: bool = True
     enable_memory_snapshot: bool = False
diff --git a/maester/models/gemma/__init__.py b/maester/models/gemma/__init__.py
index ddd0233..eabc420 100644
--- a/maester/models/gemma/__init__.py
+++ b/maester/models/gemma/__init__.py
@@ -3,6 +3,25 @@
 __all__ = ["GemmaTextModel", "ModelArgs"]
 
 gemma3_configs = {
+    "270M": ModelArgs(
+        vocab_size=262_144,
+        dim=640,
+        n_layers=18,
+        n_heads=4,
+        num_key_value_heads=1,
+        head_dim=256,
+        intermediate_size=2048,
+        attn_types=["local_sliding", "local_sliding", "local_sliding", "local_sliding", "local_sliding", "global"],
+        use_post_ffw_norm=True,
+        use_pre_ffw_norm=True,
+        sliding_window_size=512,
+        rope_wave_length={
+            "local_sliding": 10_000,
+            "global": 1_000_000,
+        },
+        use_qk_norm=True,
+        vision_config=None,
+    ),
     "1B": ModelArgs(
         vocab_size=262_144,  # Actual size from google/gemma-3-1b-pt tokenizer
         dim=1152,
diff --git a/train.py b/train.py
index bc641c0..e3eb7fb 100644
--- a/train.py
+++ b/train.py
@@ -32,6 +32,7 @@
 from maester.memory import cleanup_before_training
 from maester.metrics import build_gpu_memory_monitor, build_metric_logger, register_logits_monitoring, WeightScaleMonitor
 from maester.data_monitor import DataMonitor
+from maester.dp_privacy import DPConfig, DPSanitizer, SimplePLDAccountant, no_grad_sync_for_fsdp
 from maester.models import (
     model_name_to_cls,
     models_config,
@@ -297,6 +298,56 @@ def loss_fn(pred, labels):
     if cfg.compile:
         loss_fn = torch.compile(loss_fn)
 
+    if cfg.dp_enabled:
+        # Process groups for DP, picked explicitly from the 3-D device mesh:
+        # dp_repl_pg is used to average clipped grads across data-parallel replicas,
+        # mp_pg covers the model-parallel (shard/TP) dims and is handed to the sanitizer.
+        dp_repl_pg = None
+        try:
+            dp_repl_pg = world_mesh["dp_replicate"].get_group()
+        except Exception:
+            try:
+                dp_repl_pg = world_mesh["dp"].get_group()
world_mesh["dp"].get_group() + except Exception: + dp_repl_pg = None # single replica + + mp_pg = None + try: + # Prefer shard×TP if you have both; else TP; else shard; else None + mp_pg = world_mesh["tp","dp_shard"].get_group() + except Exception: + try: + mp_pg = world_mesh["tp"].get_group() + except Exception: + try: + mp_pg = world_mesh["dp_shard"].get_group() + except Exception: + mp_pg = None + + dp_cfg = DPConfig(C=cfg.dp_clip_norm, sigma=cfg.dp_noise_multiplier) + sanitizer = DPSanitizer(model, dp_pg=dp_repl_pg, mp_pg=mp_pg, cfg=dp_cfg) + + # --- loss_fn that produces per-sample losses --- + def per_sample_losses(logits, labels, ignore_index=-100): + # logits: [B, T, V], labels: [B, T] + # CE per token + loss_tok = F.cross_entropy( + logits.flatten(0,1).float(), labels.flatten(0,1), + reduction="none", ignore_index=ignore_index + ).view(labels.shape) # [B, T] + # mask padding + valid = (labels != ignore_index).float() + # sum over tokens -> per-sample scalar + loss_per_sample = (loss_tok * valid).sum(dim=1) # [B] + return loss_per_sample + + if cfg.compile: + per_sample_losses = torch.compile(per_sample_losses, dynamic=True) + + delta = cfg.dp_delta # e.g., 1.0 / N_priv + pld_acc = SimplePLDAccountant(delta=delta) # FFT-based PLD for Poisson subsampled Gaussian + pld_ready = True + # training loop cleanup_before_training() model.train() @@ -354,15 +405,9 @@ def loss_fn(pred, labels): input_ids = batch["input_ids"] labels = batch["labels"] - - # Get position_ids if available (currently only from packed SFT data) - # TODO: Consider generating position_ids for all data loaders for consistency position_ids = batch.get("position_ids", None) - - # Get document_ids if available (for flex attention document masking in packed data) document_ids = batch.get("document_ids", None) - # Collect padding stats if available (SFT mode) if "stats" in batch and "actual_lengths" in batch["stats"]: padding_lengths_since_last_log.append(batch["stats"]["actual_lengths"]) @@ -376,45 +421,166 @@ def loss_fn(pred, labels): if document_ids is not None: document_ids = document_ids.cuda() + # Preserve your sync policy for accumulation sync_grads_now = True if skip_sync_during_accum: - sync_grads_now = micro_idx == grad_accum_steps - 1 - + sync_grads_now = (micro_idx == grad_accum_steps - 1) if fsdp_can_toggle_sync and grad_accum_steps > 1: model.set_requires_gradient_sync(sync_grads_now) - with loss_parallel_ctx(): - if cfg.enable_cut_cross_entropy: - loss = model( - input_ids, - labels, - position_ids=position_ids, - document_ids=document_ids, + # === Branch on DP === + if not cfg.dp_enabled: + # ---------- ORIGINAL NON-DP PATH (unchanged) ---------- + with loss_parallel_ctx(): + if cfg.enable_cut_cross_entropy: + loss = model( + input_ids, + labels, + position_ids=position_ids, + document_ids=document_ids, + ) + else: + pred = model( + input_ids, + position_ids=position_ids, + document_ids=document_ids, + ) + loss = loss_fn(pred, labels) + del pred + + losses_since_last_log.append(float(loss.detach())) + (loss / grad_accum_steps).backward() + + if (fsdp_can_toggle_sync and grad_accum_steps > 1 and + skip_sync_during_accum and not sync_grads_now): + model.set_requires_gradient_sync(True) + + else: + # ---------- DP PATH ---------- + # Micro-batch size on this rank + E_local = input_ids.shape[0] + + # 1) PASS 1: collect per-sample squared norms via hooks + sanitizer.begin_microstep(E_local) + with loss_parallel_ctx(): + # forward + logits = model( + input_ids, position_ids=position_ids, 
+                # 1) PASS 1: collect per-sample squared norms via hooks
+                sanitizer.begin_microstep(E_local)
+                with loss_parallel_ctx():
+                    # forward
+                    logits = model(
+                        input_ids, position_ids=position_ids, document_ids=document_ids
                     )
-                else:
-                    pred = model(
-                        input_ids,
-                        position_ids=position_ids,
+                    loss_i = per_sample_losses(logits, labels)  # [E_local]
+                    del logits
+
+                # backward on sum to populate ghost norms; avoid grad sync & param grad all-reduce
+                with no_grad_sync_for_fsdp(model):
+                    loss_i.sum().backward()
+
+                scales = sanitizer.end_collect_and_compute_scales()  # [E_local]
+
+                # IMPORTANT: wipe param grads from pass 1; we only want pass-2 grads
+                for p in model.parameters():
+                    if p.grad is not None:
+                        p.grad = None
+
+                # 2) PASS 2: recompute, backprop clipped mean
+                with loss_parallel_ctx():
+                    logits = model(
+                        input_ids, position_ids=position_ids, document_ids=document_ids
                     )
-                    loss = loss_fn(pred, labels)
-                    del pred
-
-            losses_since_last_log.append(loss.detach())
-            scaled_loss = loss / grad_accum_steps
-            scaled_loss.backward()
-
-            if (
-                fsdp_can_toggle_sync
-                and grad_accum_steps > 1
-                and skip_sync_during_accum
-                and not sync_grads_now
-            ):
-                model.set_requires_gradient_sync(True)
-
-        grad_norms = clip_grad_norm(  # note: maester.utils.clip_grad_norm, not torch.nn.utils.clip_grad_norm_
-            model.parameters(), cfg.max_grad_norm, foreach=True
-        )
+                    loss_i = per_sample_losses(logits, labels)  # [E_local]
+                    assert loss_i.requires_grad, "loss_i lost its graph before DP backprop."
+
+                # For logging parity with the non-DP path, record a token-mean "loss":
+                # the sum of per-sample losses on this rank divided by its number of valid tokens.
+                with torch.no_grad():
+                    valid = (labels != -100).float()
+                    denom = valid.sum().clamp_min(1.0)
+                    avg_loss_like = loss_i.sum() / denom
+                    losses_since_last_log.append(float(avg_loss_like.detach()))
+
+                # Sum across dp_replicate (replicas see different examples)
+                if parallel_dims.dp_replicate_enabled and dp_repl_pg is not None:
+                    E_repl = torch.tensor([E_local], device=input_ids.device, dtype=torch.int64)
+                    dist.all_reduce(E_repl, op=dist.ReduceOp.SUM, group=dp_repl_pg)
+                    E_repl = int(E_repl.item())
+                else:
+                    E_repl = E_local
+
+                # Multiply by dp_shard size only if that dim exists AND shards see distinct samples
+                dp_shard_factor = 1
+                if parallel_dims.dp_shard_enabled:
+                    try:
+                        dp_shard_factor = world_mesh["dp_shard"].size()
+                    except KeyError:
+                        dp_shard_factor = 1  # no such dim in this layout
+
+                E_global_micro = E_repl * dp_shard_factor
+
+                # Accumulate per-step total batch (to scale noise ONCE after all microsteps)
+                if micro_idx == 0:
+                    E_global_accum = E_global_micro
+                else:
+                    E_global_accum += E_global_micro
+
+                # Backprop clipped mean for this microstep; grads accumulate across microsteps
+                sanitizer.backprop_clipped_mean(loss_i, scales, E_global=E_global_micro)
+
+                if (fsdp_can_toggle_sync and grad_accum_steps > 1 and
+                    skip_sync_during_accum and not sync_grads_now):
+                    model.set_requires_gradient_sync(True)
+
+        # === End of microstep accumulation ===
+        if not cfg.dp_enabled:
+            # Original clip + step
+            grad_norms = clip_grad_norm(model.parameters(), cfg.max_grad_norm, foreach=True)
+        else:
+            # DP: no extra clipping here. Optionally compute diagnostic norms WITHOUT clipping:
+            grad_list = [p.grad for p in model.parameters() if (p.grad is not None)]
+            grad_norms = []
+            if len(grad_list) > 0:
+                # foreach_norm is cheap; purely for metrics (matches the existing logging keys)
+                grad_norms = torch._foreach_norm(grad_list, 2)
+                grad_norms = [float(t.item()) for t in grad_norms]  # after foreach_norm
+            # (A) Sum/avg grads across DP replicas – handle DTensor grads safely
+            if dp_repl_pg is not None and dist.get_world_size(dp_repl_pg) > 1:
+                world = dist.get_world_size(dp_repl_pg)
+                with torch.no_grad():
+                    for p in model.parameters():
+                        g = getattr(p, "grad", None)
+                        if g is None:
+                            continue
+                        # IMPORTANT: don't call dist.all_reduce on a DTensor – operate on its local shard.
+                        if hasattr(torch.distributed.tensor, "DTensor") and isinstance(g, torch.distributed.tensor.DTensor):
+                            local = g.to_local()  # regular Tensor on this rank
+                            dist.all_reduce(local, op=dist.ReduceOp.SUM, group=dp_repl_pg)
+                            local.div_(world)
+                        else:
+                            dist.all_reduce(g, op=dist.ReduceOp.SUM, group=dp_repl_pg)
+                            g.div_(world)
+            if cfg.dp_assert:
+                def _pick_big_param_with_grad(model):
+                    best = None
+                    best_n = -1
+                    for p in model.parameters():
+                        if p.grad is None:
+                            continue
+                        n = p.numel()
+                        if n > best_n:
+                            best, best_n = p, n
+                    return best, best_n
+                # snapshot the probe gradient BEFORE sanitizer.add_dp_noise_
+                p_probe, _ = _pick_big_param_with_grad(model)
+                g_before_probe = (p_probe.grad.to_local() if hasattr(p_probe.grad, "to_local") else p_probe.grad).detach().clone()
+            # (B) Add identical Gaussian noise (scale by total E_global across microsteps)
+            sanitizer.add_dp_noise_(optimizer, E_global=E_global_accum, step=train_state.step)
+            if cfg.dp_assert:
+                # AFTER sanitizer.add_dp_noise_
+                g_after_probe = (p_probe.grad.to_local() if hasattr(p_probe.grad, "to_local") else p_probe.grad)
+                noise_delta = (g_after_probe - g_before_probe).float()
+                est_std = float(noise_delta.view(-1)[:262144].std().item())
+                expected = (cfg.dp_noise_multiplier * cfg.dp_clip_norm) / float(E_global_accum)
+                assert est_std > 0 and abs(est_std/expected - 1) <= 0.3, "Same-step DP noise std off."
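+                # The probe checks that the noise injected by add_dp_noise_ has per-element std of
+                # roughly sigma * C / E_global_accum (dp_noise_multiplier * dp_clip_norm divided by the
+                # step's total example count), within a 30% tolerance on a slice of the largest grad.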
 
         optimizer.step()
         scheduler.step()
         train_state.step += 1
@@ -426,7 +592,7 @@ def loss_fn(pred, labels):
 
         # log metrics
         if train_state.step == 1 or train_state.step % cfg.log_freq == 0:
-            losses = [l.detach().item() for l in losses_since_last_log]
+            losses = losses_since_last_log[:]
             avg_loss, max_loss = (
                 np.mean(losses),
                 np.max(losses),
@@ -443,15 +609,19 @@ def loss_fn(pred, labels):
 
             exp_avgs, exp_avg_sqs, param_names = [], [], []
             for group in optimizer.param_groups:
                 for p in group['params']:
-                    if p.grad is None:
-                        continue
                     state = optimizer.state[p]
-                    if 'exp_avg' in state:  # Check if states initialized
-                        exp_avgs.append(state['exp_avg'])
-                        exp_avg_sqs.append(state['exp_avg_sq'])
-                        param_names.append(param_to_name[p])
-            exp_avg_norms = torch._foreach_norm(exp_avgs, 2)
-            exp_avg_sq_norms = torch._foreach_norm(exp_avg_sqs, 2)
+                    if not state:
+                        continue
+                    ea = state.get('exp_avg', None)
+                    es = state.get('exp_avg_sq', None)
+                    if ea is not None and es is not None:
+                        exp_avgs.append(ea); exp_avg_sqs.append(es)
+                        param_names.append(param_to_name.get(p, ""))
+            if exp_avgs:
+                exp_avg_norms = [float(t.item()) for t in torch._foreach_norm(exp_avgs, 2)]
+                exp_avg_sq_norms = [float(t.item()) for t in torch._foreach_norm(exp_avg_sqs, 2)]
+            else:
+                exp_avg_norms, exp_avg_sq_norms = [], []
 
             time_delta = timer() - time_last_log
@@ -511,9 +681,16 @@ def loss_fn(pred, labels):
             })
             for i in range(len(optimizer.param_groups)):
                 metrics[f"lr/group{i}"] = scheduler.get_last_lr()[i]
-            for gn, (name, _) in zip(grad_norms, model.named_parameters()):
-                cn = clean_param_name(name)
-                metrics[f"{cn}/grad_norm"] = gn
+            if not cfg.dp_enabled:
+                # unchanged
+                for gn, (name, _) in zip(grad_norms, model.named_parameters()):
+                    cn = clean_param_name(name); metrics[f"{cn}/grad_norm"] = gn
+            else:
+                # align names with the grads we actually normed
+                named_with_grad = [(name, p) for name, p in model.named_parameters() if p.grad is not None]
+                for (name, _), gn in zip(named_with_grad, grad_norms):
+                    cn = clean_param_name(name)
+                    metrics[f"{cn}/grad_norm"] = gn
             for exp_avg_norm, exp_avg_sq_norm, name in zip(exp_avg_norms, exp_avg_sq_norms, param_names):
                 cn = clean_param_name(name)
                 metrics[f"{cn}/exp_avg_norm"] = exp_avg_norm
@@ -526,6 +703,27 @@ def loss_fn(pred, labels):
             # metrics.update(get_logits_metrics())
             if weight_scale_stats:
                 metrics.update(weight_scale_stats)
+            if cfg.dp_enabled:
+                metrics["dp/C"] = cfg.dp_clip_norm
+                metrics["dp/sigma"] = cfg.dp_noise_multiplier
+                metrics["dp/E_global_step"] = float(E_global_accum)
+
+                # local clip fraction
+                metrics["dp/clip_frac_local"] = float((scales < 1).float().mean().item())
+
+                # global clip fraction (average over DP replicas)
+                if dp_repl_pg is not None and dist.get_world_size(dp_repl_pg) > 1:
+                    t = torch.tensor([metrics["dp/clip_frac_local"]], device=input_ids.device, dtype=torch.float32)
+                    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=dp_repl_pg)
+                    t /= dist.get_world_size(dp_repl_pg)
+                    metrics["dp/clip_frac"] = float(t.item())
+
+                # privacy accounting
+                N_priv = cfg.dp_num_privacy_units
+                q_t = float(E_global_accum) / float(N_priv)
+                pld_acc.add_step(q=q_t, sigma=cfg.dp_noise_multiplier)
+                metrics["dp/q"] = q_t
+                metrics["dp/eps@delta"] = pld_acc.epsilon()
 
             if metric_logger is not None:
                 metric_logger.log(metrics, step=train_state.step)