From 7cfafe0b932f17d71386075ee21c05b2adf0fc01 Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sat, 11 Oct 2025 10:43:00 +0200 Subject: [PATCH 1/4] 270m support --- maester/models/gemma/__init__.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/maester/models/gemma/__init__.py b/maester/models/gemma/__init__.py index ddd0233..eabc420 100644 --- a/maester/models/gemma/__init__.py +++ b/maester/models/gemma/__init__.py @@ -3,6 +3,25 @@ __all__ = ["GemmaTextModel", "ModelArgs"] gemma3_configs = { + "270M": ModelArgs( + vocab_size=262_144, + dim=640, + n_layers=18, + n_heads=4, + num_key_value_heads=1, + head_dim=256, + intermediate_size=2048, + attn_types=["local_sliding", "local_sliding", "local_sliding", "local_sliding", "local_sliding", "global"], + use_post_ffw_norm=True, + use_pre_ffw_norm=True, + sliding_window_size=512, + rope_wave_length={ + "local_sliding": 10_000, + "global": 1_000_000, + }, + use_qk_norm=True, + vision_config=None, + ), "1B": ModelArgs( vocab_size=262_144, # Actual size from google/gemma-3-1b-pt tokenizer dim=1152, From ac3be4923fb1c98980f884230f36fdb049f80d02 Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sun, 12 Oct 2025 01:54:28 +0200 Subject: [PATCH 2/4] working DiffPriv for non-sharded models --- maester/config.py | 7 ++ train.py | 243 ++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 209 insertions(+), 41 deletions(-) diff --git a/maester/config.py b/maester/config.py index 5524b80..6d2fa9f 100644 --- a/maester/config.py +++ b/maester/config.py @@ -157,6 +157,13 @@ class Config(BaseSettings): enable_async_tensor_parallel: bool = False enable_compiled_autograd: bool = True + # differential privacy + dp_enabled: bool = False + dp_clip_norm: float = 1.0 + dp_noise_multiplier: float = 1.0 + dp_num_privacy_units: int = 1 + dp_delta: float = 1e-6 + # profiling enable_profiling: bool = True enable_memory_snapshot: bool = False diff --git a/train.py b/train.py index bc641c0..4bfd74b 100644 --- a/train.py +++ b/train.py @@ -32,6 +32,7 @@ from maester.memory import cleanup_before_training from maester.metrics import build_gpu_memory_monitor, build_metric_logger, register_logits_monitoring, WeightScaleMonitor from maester.data_monitor import DataMonitor +from maester.dp_privacy import DPConfig, DPSanitizer, SimplePLDAccountant, no_grad_sync_for_fsdp from maester.models import ( model_name_to_cls, models_config, @@ -297,6 +298,63 @@ def loss_fn(pred, labels): if cfg.compile: loss_fn = torch.compile(loss_fn) + if cfg.dp_enabled: + # GOOD: pick groups explicitly from your 3-D mesh + # === DP groups (replace your current block that sets dp_pg/mp_pg) === + dp_repl_pg = None + if cfg.dp_enabled and parallel_dims.dp_replicate_enabled: + try: + dp_repl_pg = world_mesh["dp_replicate"].get_group() + except KeyError: + # Case: only one DP dim named "dp" because dp_shard==1 + try: + dp_repl_pg = world_mesh["dp"].get_group() + except KeyError: + dp_repl_pg = None # single-replica fallback + + def _maybe_get_pg(*dims): + try: + return world_mesh[dims].get_group() + except KeyError: + return None + + mp_pg = None + if cfg.dp_enabled: + # Prefer joint TP×DP_shard if both present + if parallel_dims.tp_enabled and parallel_dims.dp_shard_enabled: + mp_pg = _maybe_get_pg("tp", "dp_shard") + # Fall back to TP only + if mp_pg is None and parallel_dims.tp_enabled: + mp_pg = _maybe_get_pg("tp") + # Or DP_shard only + if mp_pg is None and parallel_dims.dp_shard_enabled: + mp_pg = _maybe_get_pg("dp_shard") + # Else mp_pg stays None (no 
param sharding) + + dp_cfg = DPConfig(C=cfg.dp_clip_norm, sigma=cfg.dp_noise_multiplier) + sanitizer = DPSanitizer(model, dp_pg=dp_repl_pg, mp_pg=mp_pg, cfg=dp_cfg) + + # --- loss_fn that produces per-sample losses --- + def per_sample_losses(logits, labels, ignore_index=-100): + # logits: [B, T, V], labels: [B, T] + # CE per token + loss_tok = F.cross_entropy( + logits.flatten(0,1).float(), labels.flatten(0,1), + reduction="none", ignore_index=ignore_index + ).view(labels.shape) # [B, T] + # mask padding + valid = (labels != ignore_index).float() + # sum over tokens -> per-sample scalar + loss_per_sample = (loss_tok * valid).sum(dim=1) # [B] + return loss_per_sample + + if cfg.compile: + per_sample_losses = torch.compile(per_sample_losses, dynamic=True) + + delta = cfg.dp_delta # e.g., 1.0 / N_priv + pld_acc = SimplePLDAccountant(delta=delta) # FFT-based PLD for Poisson subsampled Gaussian + pld_ready = True + # training loop cleanup_before_training() model.train() @@ -354,15 +412,9 @@ def loss_fn(pred, labels): input_ids = batch["input_ids"] labels = batch["labels"] - - # Get position_ids if available (currently only from packed SFT data) - # TODO: Consider generating position_ids for all data loaders for consistency position_ids = batch.get("position_ids", None) - - # Get document_ids if available (for flex attention document masking in packed data) document_ids = batch.get("document_ids", None) - # Collect padding stats if available (SFT mode) if "stats" in batch and "actual_lengths" in batch["stats"]: padding_lengths_since_last_log.append(batch["stats"]["actual_lengths"]) @@ -376,45 +428,129 @@ def loss_fn(pred, labels): if document_ids is not None: document_ids = document_ids.cuda() + # Preserve your sync policy for accumulation sync_grads_now = True if skip_sync_during_accum: - sync_grads_now = micro_idx == grad_accum_steps - 1 - + sync_grads_now = (micro_idx == grad_accum_steps - 1) if fsdp_can_toggle_sync and grad_accum_steps > 1: model.set_requires_gradient_sync(sync_grads_now) - with loss_parallel_ctx(): - if cfg.enable_cut_cross_entropy: - loss = model( - input_ids, - labels, - position_ids=position_ids, - document_ids=document_ids, + # === Branch on DP === + if not cfg.dp_enabled: + # ---------- ORIGINAL NON-DP PATH (unchanged) ---------- + with loss_parallel_ctx(): + if cfg.enable_cut_cross_entropy: + loss = model( + input_ids, + labels, + position_ids=position_ids, + document_ids=document_ids, + ) + else: + pred = model( + input_ids, + position_ids=position_ids, + document_ids=document_ids, + ) + loss = loss_fn(pred, labels) + del pred + + losses_since_last_log.append(loss.detach()) + (loss / grad_accum_steps).backward() + + if (fsdp_can_toggle_sync and grad_accum_steps > 1 and + skip_sync_during_accum and not sync_grads_now): + model.set_requires_gradient_sync(True) + + else: + # ---------- DP PATH ---------- + # Micro-batch size on this rank + E_local = input_ids.shape[0] + + # 1) PASS 1: collect per-sample squared norms via hooks + sanitizer.begin_microstep(E_local) + with loss_parallel_ctx(): + # forward + logits = model( + input_ids, position_ids=position_ids, document_ids=document_ids ) - else: - pred = model( - input_ids, - position_ids=position_ids, - document_ids=document_ids, + loss_i = per_sample_losses(logits, labels) # [E_local] + del logits + + # backward on sum to populate ghost norms; avoid grad sync & param grad all-reduce + with no_grad_sync_for_fsdp(model): + loss_i.sum().backward() + + scales = sanitizer.end_collect_and_compute_scales() # [E_local] + + # 
IMPORTANT: wipe param grads from pass 1; we only want pass-2 grads + for p in model.parameters(): + if p.grad is not None: + p.grad = None + + # 2) PASS 2: recompute, backprop clipped mean + with loss_parallel_ctx(): + logits = model( + input_ids, position_ids=position_ids, document_ids=document_ids ) - loss = loss_fn(pred, labels) - del pred - - losses_since_last_log.append(loss.detach()) - scaled_loss = loss / grad_accum_steps - scaled_loss.backward() - - if ( - fsdp_can_toggle_sync - and grad_accum_steps > 1 - and skip_sync_during_accum - and not sync_grads_now - ): - model.set_requires_gradient_sync(True) - - grad_norms = clip_grad_norm( # note: maester.utils.clip_grad_norm, not torch.nn.utils.clip_grad_norm_ - model.parameters(), cfg.max_grad_norm, foreach=True - ) + loss_i = per_sample_losses(logits, labels) # [E_local] + assert loss_i.requires_grad, "loss_i lost its graph before DP backprop." + + # For logging parity, record token-mean "loss" like your original + # (sum over tokens per sample divided by number of valid tokens) + # We approximate via batch mean of per-sample sums divided by seq len of valid tokens on this rank. + # For stability (and same units as before), log the average per-sample sum / tokens_per_sample_mean. + with torch.no_grad(): + valid = (labels != -100).float() + denom = valid.sum().clamp_min(1.0) + avg_loss_like = loss_i.sum() / denom + losses_since_last_log.append(avg_loss_like.detach()) + + # Sum across dp_replicate (replicas see different examples) + if parallel_dims.dp_replicate_enabled and dp_repl_pg is not None: + E_repl = torch.tensor([E_local], device=input_ids.device, dtype=torch.int64) + dist.all_reduce(E_repl, op=dist.ReduceOp.SUM, group=dp_repl_pg) + E_repl = int(E_repl.item()) + else: + E_repl = E_local + + # Multiply by dp_shard size only if that dim exists AND shards see distinct samples + dp_shard_factor = 1 + if parallel_dims.dp_shard_enabled: + try: + dp_shard_factor = world_mesh["dp_shard"].size() + except KeyError: + dp_shard_factor = 1 # no such dim in this layout + + E_global_micro = E_repl * dp_shard_factor + + # Accumulate per-step total batch (to scale noise ONCE after all microsteps) + if micro_idx == 0: + E_global_accum = E_global_micro + else: + E_global_accum += E_global_micro + + # Backprop clipped mean for this microstep; grads accumulate across microsteps + sanitizer.backprop_clipped_mean(loss_i, scales, E_global=E_global_micro) + + if (fsdp_can_toggle_sync and grad_accum_steps > 1 and + skip_sync_during_accum and not sync_grads_now): + model.set_requires_gradient_sync(True) + + # === End of microstep accumulation === + if not cfg.dp_enabled: + # Original clip + step + grad_norms = clip_grad_norm(model.parameters(), cfg.max_grad_norm, foreach=True) + else: + # DP: no extra clipping here. 
Optionally compute diagnostic norms WITHOUT clipping: + grad_list = [p.grad for p in model.parameters() if (p.grad is not None)] + grad_norms = [] + if len(grad_list) > 0: + # foreach_norm is cheap; purely for metrics (matches your logging keys) + grad_norms = torch._foreach_norm(grad_list, 2) + # Add identical noise across DP replicas ONCE per step; scale by total E_global over microsteps + sanitizer.add_dp_noise_(optimizer, E_global=E_global_accum, step=train_state.step) + optimizer.step() scheduler.step() train_state.step += 1 @@ -511,9 +647,16 @@ def loss_fn(pred, labels): }) for i in range(len(optimizer.param_groups)): metrics[f"lr/group{i}"] = scheduler.get_last_lr()[i] - for gn, (name, _) in zip(grad_norms, model.named_parameters()): - cn = clean_param_name(name) - metrics[f"{cn}/grad_norm"] = gn + if not cfg.dp_enabled: + # unchanged + for gn, (name, _) in zip(grad_norms, model.named_parameters()): + cn = clean_param_name(name); metrics[f"{cn}/grad_norm"] = gn + else: + # align names with the grads we actually normed + named_with_grad = [(name, p) for name, p in model.named_parameters() if p.grad is not None] + for (name, _), gn in zip(named_with_grad, grad_norms): + cn = clean_param_name(name) + metrics[f"{cn}/grad_norm"] = gn for exp_avg_norm, exp_avg_sq_norm, name in zip(exp_avg_norms, exp_avg_sq_norms, param_names): cn = clean_param_name(name) metrics[f"{cn}/exp_avg_norm"] = exp_avg_norm @@ -526,6 +669,24 @@ def loss_fn(pred, labels): # metrics.update(get_logits_metrics()) if weight_scale_stats: metrics.update(weight_scale_stats) + if cfg.dp_enabled: + # local + clip_frac_local = float((scales < 1).float().mean().item()) + metrics["dp/clip_frac_local"] = clip_frac_local + # global mean over replicas + if parallel_dims.dp_replicate_enabled and dp_pg is not None: + t = torch.tensor([clip_frac_local], device=input_ids.device, dtype=torch.float32) + dist.all_reduce(t, op=dist.ReduceOp.SUM, group=dp_repl_pg) + t /= dist.get_world_size(dp_pg) + metrics["dp/clip_frac"] = float(t.item()) + metrics["dp/C"] = cfg.dp_clip_norm + metrics["dp/sigma"] = cfg.dp_noise_multiplier + metrics["dp/E_global_step"] = float(E_global_accum) + N_priv = cfg.dp_num_privacy_units # total #examples in *private* set + q_t = float(E_global_accum) / float(N_priv) + pld_acc.add_step(q=q_t, sigma=cfg.dp_noise_multiplier) + metrics["dp/q"] = q_t + metrics["dp/eps@delta"] = pld_acc.epsilon() if pld_ready else float("nan") if metric_logger is not None: metric_logger.log(metrics, step=train_state.step) From a9976676acca5cdb9146a5eeb70f667d03f37add Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sun, 12 Oct 2025 03:09:57 +0200 Subject: [PATCH 3/4] make Comma7B work with FSDP --- maester/config.py | 1 + train.py | 135 +++++++++++++++++++++++++++++----------------- 2 files changed, 87 insertions(+), 49 deletions(-) diff --git a/maester/config.py b/maester/config.py index 6d2fa9f..0c3f8e0 100644 --- a/maester/config.py +++ b/maester/config.py @@ -163,6 +163,7 @@ class Config(BaseSettings): dp_noise_multiplier: float = 1.0 dp_num_privacy_units: int = 1 dp_delta: float = 1e-6 + dp_assert: bool = False # profiling enable_profiling: bool = True diff --git a/train.py b/train.py index 4bfd74b..e3eb7fb 100644 --- a/train.py +++ b/train.py @@ -301,35 +301,28 @@ def loss_fn(pred, labels): if cfg.dp_enabled: # GOOD: pick groups explicitly from your 3-D mesh # === DP groups (replace your current block that sets dp_pg/mp_pg) === + # --- Process groups --- dp_repl_pg = None - if cfg.dp_enabled and 
parallel_dims.dp_replicate_enabled: + try: + dp_repl_pg = world_mesh["dp_replicate"].get_group() + except Exception: try: - dp_repl_pg = world_mesh["dp_replicate"].get_group() - except KeyError: - # Case: only one DP dim named "dp" because dp_shard==1 - try: - dp_repl_pg = world_mesh["dp"].get_group() - except KeyError: - dp_repl_pg = None # single-replica fallback - - def _maybe_get_pg(*dims): - try: - return world_mesh[dims].get_group() - except KeyError: - return None + dp_repl_pg = world_mesh["dp"].get_group() + except Exception: + dp_repl_pg = None # single replica mp_pg = None - if cfg.dp_enabled: - # Prefer joint TP×DP_shard if both present - if parallel_dims.tp_enabled and parallel_dims.dp_shard_enabled: - mp_pg = _maybe_get_pg("tp", "dp_shard") - # Fall back to TP only - if mp_pg is None and parallel_dims.tp_enabled: - mp_pg = _maybe_get_pg("tp") - # Or DP_shard only - if mp_pg is None and parallel_dims.dp_shard_enabled: - mp_pg = _maybe_get_pg("dp_shard") - # Else mp_pg stays None (no param sharding) + try: + # Prefer shard×TP if you have both; else TP; else shard; else None + mp_pg = world_mesh["tp","dp_shard"].get_group() + except Exception: + try: + mp_pg = world_mesh["tp"].get_group() + except Exception: + try: + mp_pg = world_mesh["dp_shard"].get_group() + except Exception: + mp_pg = None dp_cfg = DPConfig(C=cfg.dp_clip_norm, sigma=cfg.dp_noise_multiplier) sanitizer = DPSanitizer(model, dp_pg=dp_repl_pg, mp_pg=mp_pg, cfg=dp_cfg) @@ -455,7 +448,7 @@ def per_sample_losses(logits, labels, ignore_index=-100): loss = loss_fn(pred, labels) del pred - losses_since_last_log.append(loss.detach()) + losses_since_last_log.append(float(loss.detach())) (loss / grad_accum_steps).backward() if (fsdp_can_toggle_sync and grad_accum_steps > 1 and @@ -504,7 +497,7 @@ def per_sample_losses(logits, labels, ignore_index=-100): valid = (labels != -100).float() denom = valid.sum().clamp_min(1.0) avg_loss_like = loss_i.sum() / denom - losses_since_last_log.append(avg_loss_like.detach()) + losses_since_last_log.append(float(avg_loss_like.detach())) # Sum across dp_replicate (replicas see different examples) if parallel_dims.dp_replicate_enabled and dp_repl_pg is not None: @@ -548,9 +541,46 @@ def per_sample_losses(logits, labels, ignore_index=-100): if len(grad_list) > 0: # foreach_norm is cheap; purely for metrics (matches your logging keys) grad_norms = torch._foreach_norm(grad_list, 2) - # Add identical noise across DP replicas ONCE per step; scale by total E_global over microsteps + grad_norms = [float(t.item()) for t in grad_norms] # after foreach_norm + # (A) Sum/avg grads across DP replicas – handle DTensor grads safely + if dp_repl_pg is not None and dist.get_world_size(dp_repl_pg) > 1: + world = dist.get_world_size(dp_repl_pg) + with torch.no_grad(): + for p in model.parameters(): + g = getattr(p, "grad", None) + if g is None: + continue + # IMPORTANT: don't call dist.all_reduce on a DTensor – operate on its local shard. 
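+                        # (Why this branch exists: under FSDP2 / fully_shard the parameter
+                        # gradients are DTensors sharded over dp_shard, so we all-reduce only
+                        # the local shard across the dp_replicate group. to_local() exposes
+                        # the shard's storage, so the in-place all_reduce + div_ below update
+                        # the DTensor gradient directly, with no explicit write-back needed.)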
+ if hasattr(torch.distributed.tensor, "DTensor") and isinstance(g, torch.distributed.tensor.DTensor): + local = g.to_local() # regular Tensor on this rank + dist.all_reduce(local, op=dist.ReduceOp.SUM, group=dp_repl_pg) + local.div_(world) + else: + dist.all_reduce(g, op=dist.ReduceOp.SUM, group=dp_repl_pg) + g.div_(world) + if cfg.dp_assert: + def _pick_big_param_with_grad(model): + best = None + best_n = -1 + for p in model.parameters(): + if p.grad is None: + continue + n = p.numel() + if n > best_n: + best, best_n = p, n + return best, best_n + # (B) Add identical Gaussian noise (scale by total E_global across microsteps) + # BEFORE sanitizer.add_dp_noise_ + p_probe, _ = _pick_big_param_with_grad(model) + g_before_probe = (p_probe.grad.to_local() if hasattr(p_probe.grad, "to_local") else p_probe.grad).detach().clone() sanitizer.add_dp_noise_(optimizer, E_global=E_global_accum, step=train_state.step) - + if cfg.dp_assert: + # AFTER sanitizer.add_dp_noise_ + g_after_probe = (p_probe.grad.to_local() if hasattr(p_probe.grad, "to_local") else p_probe.grad) + delta = (g_after_probe - g_before_probe).float() + est_std = float(delta.view(-1)[:262144].std().item()) + expected = (cfg.dp_noise_multiplier * cfg.dp_clip_norm) / float(E_global_accum) + assert est_std > 0 and abs(est_std/expected - 1) <= 0.3, "Same-step DP noise std off." optimizer.step() scheduler.step() train_state.step += 1 @@ -562,7 +592,7 @@ def per_sample_losses(logits, labels, ignore_index=-100): # log metrics if train_state.step == 1 or train_state.step % cfg.log_freq == 0: - losses = [l.detach().item() for l in losses_since_last_log] + losses = losses_since_last_log[:] avg_loss, max_loss = ( np.mean(losses), np.max(losses), @@ -579,15 +609,19 @@ def per_sample_losses(logits, labels, ignore_index=-100): exp_avgs, exp_avg_sqs, param_names = [], [], [] for group in optimizer.param_groups: for p in group['params']: - if p.grad is None: - continue state = optimizer.state[p] - if 'exp_avg' in state: # Check if states initialized - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - param_names.append(param_to_name[p]) - exp_avg_norms = torch._foreach_norm(exp_avgs, 2) - exp_avg_sq_norms = torch._foreach_norm(exp_avg_sqs, 2) + if not state: + continue + ea = state.get('exp_avg', None) + es = state.get('exp_avg_sq', None) + if ea is not None and es is not None: + exp_avgs.append(ea); exp_avg_sqs.append(es) + param_names.append(param_to_name.get(p, "")) + if exp_avgs: + exp_avg_norms = [float(t.item()) for t in torch._foreach_norm(exp_avgs, 2)] + exp_avg_sq_norms = [float(t.item()) for t in torch._foreach_norm(exp_avg_sqs, 2)] + else: + exp_avg_norms, exp_avg_sq_norms = [], [] time_delta = timer() - time_last_log @@ -670,23 +704,26 @@ def per_sample_losses(logits, labels, ignore_index=-100): if weight_scale_stats: metrics.update(weight_scale_stats) if cfg.dp_enabled: - # local - clip_frac_local = float((scales < 1).float().mean().item()) - metrics["dp/clip_frac_local"] = clip_frac_local - # global mean over replicas - if parallel_dims.dp_replicate_enabled and dp_pg is not None: - t = torch.tensor([clip_frac_local], device=input_ids.device, dtype=torch.float32) - dist.all_reduce(t, op=dist.ReduceOp.SUM, group=dp_repl_pg) - t /= dist.get_world_size(dp_pg) - metrics["dp/clip_frac"] = float(t.item()) metrics["dp/C"] = cfg.dp_clip_norm metrics["dp/sigma"] = cfg.dp_noise_multiplier metrics["dp/E_global_step"] = float(E_global_accum) - N_priv = cfg.dp_num_privacy_units # total #examples in *private* set + + # 
local clip fraction + metrics["dp/clip_frac_local"] = float((scales < 1).float().mean().item()) + + # global clip fraction (average over DP replicas) + if dp_repl_pg is not None and dist.get_world_size(dp_repl_pg) > 1: + t = torch.tensor([metrics["dp/clip_frac_local"]], device=input_ids.device, dtype=torch.float32) + dist.all_reduce(t, op=dist.ReduceOp.SUM, group=dp_repl_pg) + t /= dist.get_world_size(dp_repl_pg) + metrics["dp/clip_frac"] = float(t.item()) + + # privacy accounting + N_priv = cfg.dp_num_privacy_units q_t = float(E_global_accum) / float(N_priv) pld_acc.add_step(q=q_t, sigma=cfg.dp_noise_multiplier) metrics["dp/q"] = q_t - metrics["dp/eps@delta"] = pld_acc.epsilon() if pld_ready else float("nan") + metrics["dp/eps@delta"] = pld_acc.epsilon() if metric_logger is not None: metric_logger.log(metrics, step=train_state.step) From f4b6628491e484b3a85e189781ea519591f606f7 Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sun, 12 Oct 2025 10:38:45 +0200 Subject: [PATCH 4/4] config --- configs/llama/munin-7b-core-pt-dp.toml | 49 ++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 configs/llama/munin-7b-core-pt-dp.toml diff --git a/configs/llama/munin-7b-core-pt-dp.toml b/configs/llama/munin-7b-core-pt-dp.toml new file mode 100644 index 0000000..6836279 --- /dev/null +++ b/configs/llama/munin-7b-core-pt-dp.toml @@ -0,0 +1,49 @@ +model_name = "llama3" +flavor = "Comma7B" +tokenizer_name = "common-pile/comma-v0.1-2t" + +# job +job_name = "munin-7b-core-pt-dp" +wandb_project = "munin-7b-core-pt-dp" +enable_wandb = false + +# parallelism +num_nodes = 1 +data_parallel_shard_degree = 8 +data_parallel_replicate_degree = 1 + +# training settings +train_batch_size = 4 +seq_len = 4096 +gradient_accumulation_steps = 2 +train_num_steps = 60097 +scheduler = "linear_warmup_constant_sqrt_decay" +warmup_steps = 1000 +cooldown_steps = 1000 +checkpoint_interval = 500 +forced_load_path = "/work/training/maester/comma-v0.1-2t-dcp/" +compile = true +enable_cut_cross_entropy = false +ac_mode = "none" +selective_ac_option = "op" + +dp_enabled = true +dp_clip_norm = 1.0 +dp_noise_multiplier = 1.0 +dp_num_privacy_units = 10000 +dp_delta = 1e-6 + +[dataset] +bos_token = 2 +eos_token = 1 +data_dirs = [ + "/work/data/dfm-common-pile-16_3/", +] +dataset_weights = "1.0" + +[opt_cfg] # must specify *all* fields here, will not merge with defaults +lr = 1e-5 +betas = [0.9, 0.95] +weight_decay = 0.1 +eps = 1e-9 +fused = true
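Note (illustrative only, not part of the patches): the DP path above reduces to standard DP-SGD: per-sample clip scales min(1, C/||g_i||), a clipped mean over the example count E, and one Gaussian draw with std sigma*C/E per optimizer step (the dp_assert probe in PATCH 3 checks exactly this std). The sketch below reproduces that math on a toy, unsharded model in plain PyTorch. The model, names, and the explicit per-sample backward loop are illustrative assumptions; DPSanitizer is assumed to obtain the per-sample norms from its pass-1 hooks instead of materializing per-sample gradients.

# Toy DP-SGD step mirroring the C / sigma / E_global semantics used above.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(16, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

C, sigma = 1.0, 1.0                # dp_clip_norm, dp_noise_multiplier
x = torch.randn(8, 16)             # E = 8 examples in the (global) batch
y = torch.randint(0, 4, (8,))
E = x.shape[0]

# Pass 1: per-sample gradient norms -> clip scales s_i = min(1, C / ||g_i||).
norms = []
for i in range(E):
    model.zero_grad(set_to_none=True)
    F.cross_entropy(model(x[i:i + 1]), y[i:i + 1]).backward()
    sq = sum(p.grad.pow(2).sum() for p in model.parameters())
    norms.append(sq.sqrt())
scales = (C / (torch.stack(norms) + 1e-12)).clamp(max=1.0)

# Pass 2: accumulate the clipped mean gradient sum_i s_i * g_i / E.
model.zero_grad(set_to_none=True)
for i in range(E):
    (scales[i] * F.cross_entropy(model(x[i:i + 1]), y[i:i + 1]) / E).backward()

# Noise: one draw per step with std sigma * C / E, matching the dp_assert check.
with torch.no_grad():
    for p in model.parameters():
        p.grad.add_(torch.randn_like(p.grad) * (sigma * C / E))
opt.step()

In the patches these quantities appear as cfg.dp_clip_norm, cfg.dp_noise_multiplier, and E_global_accum; the noise is drawn once per optimizer step after all gradient-accumulation microsteps, and identically across DP replicas so that replica gradients stay in sync.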