From 9a57ae2c77798ba817e7dbd0128cfceb8abed1e2 Mon Sep 17 00:00:00 2001 From: Oliver Kinch Date: Thu, 13 Nov 2025 10:03:15 +0100 Subject: [PATCH 1/4] CP --- maester/config.py | 1 + maester/models/gemma/model.py | 25 ++++++---- maester/models/glm4/model.py | 2 +- maester/models/llama/model.py | 4 +- maester/parallelisms/parallel_dims.py | 61 ++++++++++++++--------- maester/parallelisms/parallelize_gemma.py | 8 ++- maester/parallelisms/parallelize_llama.py | 39 +++++++-------- train.py | 49 ++++++++++++++---- 8 files changed, 117 insertions(+), 72 deletions(-) diff --git a/maester/config.py b/maester/config.py index 38d1068..b739062 100644 --- a/maester/config.py +++ b/maester/config.py @@ -82,6 +82,7 @@ class Config(BaseSettings): data_parallel_shard_degree: int = 8 data_parallel_replicate_degree: int = 1 tensor_parallel_degree: int = 1 + context_parallel_degree: int = 1 expert_parallel_degree: int = 1 train_batch_size: int = 2 # per device; 2 * 8 gpus * 32 nodes * 8192 seqlen = ~4M tokens per batch gradient_accumulation_steps: int = 1 diff --git a/maester/models/gemma/model.py b/maester/models/gemma/model.py index c140372..a7e92d8 100644 --- a/maester/models/gemma/model.py +++ b/maester/models/gemma/model.py @@ -6,6 +6,7 @@ from torch import nn from torch.nn.attention.flex_attention import create_block_mask from torch.nn.attention.flex_attention import flex_attention as _flex_attention +from torch.distributed import DeviceMesh from maester.log_utils import logger @@ -285,7 +286,8 @@ class GemmaAttention(nn.Module): def __init__( self, config: ModelArgs, - attn_type: str + attn_type: str, + cp_device_mesh: DeviceMesh | None ): super().__init__() @@ -448,13 +450,15 @@ class Gemma2DecoderLayer(nn.Module): def __init__( self, config: ModelArgs, - attn_type: str + attn_type: str, + cp_device_mesh: DeviceMesh | None ): super().__init__() self.attn_type = attn_type self.self_attn = GemmaAttention( config=config, - attn_type=attn_type + attn_type=attn_type, + cp_device_mesh=cp_device_mesh ) self.mlp = GemmaMLP( hidden_size=config.dim, @@ -523,7 +527,7 @@ def init_weights(self, init_std: float): self.mlp.init_weights(init_std) class GemmaModel(nn.Module): - def __init__(self, config: ModelArgs): + def __init__(self, config: ModelArgs, cp_device_mesh: DeviceMesh | None): super().__init__() self.config = config self.vocab_size = config.vocab_size @@ -535,7 +539,7 @@ def __init__(self, config: ModelArgs): if config.attn_types is not None else "global" ) - self.layers.append(Gemma2DecoderLayer(config, attn_type)) + self.layers.append(Gemma2DecoderLayer(config, attn_type, cp_device_mesh)) self.norm = RMSNorm(config.dim, eps=config.rms_norm_eps) def forward( @@ -569,13 +573,14 @@ def init_weights(self, init_std: float): class GemmaTextModel(nn.Module): """Text-only Gemma model compatible with training setup.""" - def __init__(self, config: ModelArgs): + def __init__(self, config: ModelArgs, cp_device_mesh: DeviceMesh | None = None): super().__init__() self.config = config self.model_args = config # For compatibility with training code self.vocab_size = config.vocab_size self.n_layers = config.n_layers - + + self.cp_device_mesh = cp_device_mesh # Text embeddings self.tok_embeddings = Embedding( num_embeddings=config.vocab_size, @@ -583,7 +588,7 @@ def __init__(self, config: ModelArgs): ) # Core transformer model - self.model = GemmaModel(config) + self.model = GemmaModel(config, cp_device_mesh=cp_device_mesh) # Precompute RoPE frequencies following multimodal pattern head_dim = config.head_dim @@ -772,9 +777,9 @@ def forward( return output @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "GemmaTextModel": + def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: DeviceMesh | None = None) -> "GemmaTextModel": """Initialize from model args (compatible with training loop).""" - return cls(model_args) + return cls(model_args, cp_device_mesh=cp_device_mesh) class Gemma3MultiModalModel(nn.Module): diff --git a/maester/models/glm4/model.py b/maester/models/glm4/model.py index 5acee0d..82b3f8c 100644 --- a/maester/models/glm4/model.py +++ b/maester/models/glm4/model.py @@ -604,7 +604,7 @@ def _process_hidden_states( return hidden_states @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "Glm4MoeTextModel": + def from_model_args(cls, model_args: ModelArgs, cp_device_mesh=None) -> "Glm4MoeTextModel": """Initialize from model args (compatible with training loop).""" return cls(model_args) diff --git a/maester/models/llama/model.py b/maester/models/llama/model.py index 90b3d63..acca71f 100644 --- a/maester/models/llama/model.py +++ b/maester/models/llama/model.py @@ -15,6 +15,7 @@ import torch import torch.nn.functional as F from torch import nn +from torch.distributed.device_mesh import DeviceMesh from maester.models.llama.tied_linear import TiedLinear from maester.models.norms import create_norm @@ -528,12 +529,13 @@ def forward( return output @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: DeviceMesh | None = None) -> "Transformer": """ Initialize a Transformer model from a ModelArgs object. Args: model_args (ModelArgs): Model configuration arguments. + cp_device_mesh (Optional[DeviceMesh]): Device mesh for context parallelism. Returns: Transformer: Transformer model. diff --git a/maester/parallelisms/parallel_dims.py b/maester/parallelisms/parallel_dims.py index 18c77b3..58cae8d 100644 --- a/maester/parallelisms/parallel_dims.py +++ b/maester/parallelisms/parallel_dims.py @@ -11,6 +11,7 @@ class ParallelDims: dp_replicate: int dp_shard: int + cp: int tp: int # cp: int # TODO: implement context parallelism ep: int @@ -23,27 +24,27 @@ def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, tp, ep = self.dp_replicate, self.dp_shard, self.tp, self.ep - for d in (dp_replicate, tp, ep): + dp_replicate, dp_shard, tp, ep, cp = self.dp_replicate, self.dp_shard, self.tp, self.ep, self.cp + for d in (dp_replicate, tp, ep, cp): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." dp = dp_replicate * dp_shard if dp < 0: dp = self.world_size // (tp) - self.dp_shard = dp_shard = dp // dp_replicate + self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp * cp) assert dp_replicate >= 1 assert dp_shard >= 1 assert tp >= 1, tp - assert dp_replicate * dp_shard * tp == self.world_size, ( + assert cp >= 1, cp + assert dp_replicate * dp_shard * tp * cp == self.world_size, ( f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " - f"tp({tp}) != WORLD_SIZE({self.world_size})" + f"tp({tp}) * cp({cp}) != WORLD_SIZE({self.world_size})" ) if ep > 1: - #assert ep % cp == 0 and (dp_shard * cp) % ep == 0 - assert ep % tp == 0 and (dp_shard * tp) % ep == 0 + assert ep % cp == 0 and (dp_shard * cp) % ep == 0 def build_mesh(self): if self.ep > 1: @@ -56,8 +57,8 @@ def _build_mesh_without_ep(self) -> DeviceMesh: dims = [] names = [] for d, name in zip( - [self.dp_replicate, self.dp_shard, self.tp], - ["dp_replicate", "dp_shard_cp", "tp"], + [self.dp_replicate, self.dp_shard, self.cp, self.tp], + ["dp_replicate", "dp_shard_cp", "cp", "tp"], ): if d > 1: dims.append(d) @@ -84,9 +85,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh: dp_mesh_dim_names.append("dp_shard_cp") dp_shard_cp_mesh_dim_names.append("dp_shard_cp") dp_cp_mesh_dim_names.append("dp_shard_cp") - # if self.cp_enabled: - # dp_shard_cp_mesh_dim_names.append("cp") - # dp_cp_mesh_dim_names.append("cp") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") if dp_mesh_dim_names != []: mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") @@ -103,9 +104,8 @@ def _build_mesh_with_ep(self) -> DeviceMesh: # With ep, dp_shard and ep are derived submeshes: # dp_shard = dp_shard_mod_ep * dp_shard_in_ep # ep = dp_shard_in_ep * cp - # NOTE: cp not implemented - dp_shard_mod_ep = self.dp_shard * self.tp // self.ep - dp_shard_in_ep = self.ep // self.tp + dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep + dp_shard_in_ep = self.ep // (self.cp * self.tp) dims = [] names = [] @@ -114,9 +114,10 @@ def _build_mesh_with_ep(self) -> DeviceMesh: self.dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, + self.cp, self.tp, ], - ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "tp"], + ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], ): # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping # helps the MoE layers do mixed precision training @@ -150,10 +151,10 @@ def _build_mesh_with_ep(self) -> DeviceMesh: dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") dp_cp_mesh_dim_names.append("dp_shard_in_ep") ep_mesh_dim_names.append("dp_shard_in_ep") - # if self.cp_enabled: - # dp_shard_cp_mesh_dim_names.append("cp") - # dp_cp_mesh_dim_names.append("cp") - # ep_mesh_dim_names.append("cp") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + ep_mesh_dim_names.append("cp") mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") @@ -176,6 +177,18 @@ def dp_replicate_enabled(self): def dp_shard_enabled(self): return self.dp_shard > 1 + @property + def dp_cp_enabled(self): + return self.dp_enabled or self.cp_enabled + + @property + def cp_enabled(self): + return self.cp > 1 + + @property + def fsdp_enabled(self): + return self.dp_shard_enabled or self.cp_enabled + @property def tp_enabled(self): return self.tp > 1 @@ -184,10 +197,6 @@ def tp_enabled(self): def ep_enabled(self): return self.ep > 1 - @property - def fsdp_enabled(self): - return self.dp_shard_enabled - @property def loss_parallel_enabled(self): return self.tp > 1 and self.enable_loss_parallel @@ -211,6 +220,10 @@ def world_mesh(self) -> DeviceMesh: @cached_property def model_parallel_size(self): return self.tp + + @cached_property + def non_data_parallel_size(self): + return self.cp * self.tp @cached_property def dense_params_mesh_ndim(self): diff --git a/maester/parallelisms/parallelize_gemma.py b/maester/parallelisms/parallelize_gemma.py index 3e0b366..8ecce3d 100644 --- a/maester/parallelisms/parallelize_gemma.py +++ b/maester/parallelisms/parallelize_gemma.py @@ -53,7 +53,7 @@ def parallelize_gemma( # Compile each layer individually if config.compile: - apply_compile(model) + apply_compile(model, fullgraph=not parallel_dims.cp_enabled) # TODO: fullgraph for CP? # Apply FSDP use_fsdp = parallel_dims.dp_enabled or ( @@ -69,7 +69,6 @@ def parallelize_gemma( dp_mesh, param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce], - tp_enabled=parallel_dims.tp_enabled, #pp_enabled=parallel_dims.pp_enabled, ) @@ -257,10 +256,10 @@ def apply_ac(model: nn.Module, config: Config): logger.info("Applied activation checkpointing to the model") -def apply_compile(model: nn.Module): +def apply_compile(model: nn.Module, fullgraph: bool = False): """Compile each transformer layer individually.""" for layer_id, layer in enumerate(model.model.layers): - compiled_layer = torch.compile(layer, fullgraph=True) + compiled_layer = torch.compile(layer, fullgraph=fullgraph) model.model.layers[layer_id] = compiled_layer logger.info("Compiled each transformer layer with torch.compile") @@ -270,7 +269,6 @@ def apply_fsdp( dp_mesh: DeviceMesh, param_dtype: torch.dtype, reduce_dtype: torch.dtype, - tp_enabled: bool, pp_enabled: bool = False, ): """Apply FSDP to Gemma model.""" diff --git a/maester/parallelisms/parallelize_llama.py b/maester/parallelisms/parallelize_llama.py index 41dc9d8..94766e9 100644 --- a/maester/parallelisms/parallelize_llama.py +++ b/maester/parallelisms/parallelize_llama.py @@ -60,28 +60,25 @@ def parallelize_llama( "fused_rmsnorm is not compatible with torch.compile yet. " "Please use rmsnorm or layernorm." ) - apply_compile(model) + apply_compile(model, fullgraph=not parallel_dims.cp_enabled) - use_fsdp = parallel_dims.dp_shard_enabled or ( - world_mesh.ndim == 1 and world_mesh.size() == 1 - ) - if use_fsdp: - if parallel_dims.dp_shard_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mesh = world_mesh["dp_replicate", "dp_shard"] + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + if parallel_dims.dp_replicate_enabled: + if parallel_dims.cp_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") else: - dp_mesh = world_mesh["dp"] + dp_mesh_dim_names = ("dp_replicate", "dp_shard") else: - dp_mesh = world_mesh if world_mesh.ndim == 1 else world_mesh["dp"] - - apply_fsdp( - model, - dp_mesh, - param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param], - reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce], - ) - if parallel_dims.dp_shard_enabled and parallel_dims.dp_replicate_enabled: + if parallel_dims.cp_enabled: + dp_mesh_dim_names = ("dp_shard_cp",) + else: + dp_mesh_dim_names = ("dp",) + + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + apply_fsdp(model, dp_mesh, param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce]) + if parallel_dims.dp_replicate_enabled: logger.info("Applied HSDP to the model") else: logger.info("Applied FSDP to the model") @@ -250,16 +247,16 @@ def apply_ac(model: nn.Module, ac_config: Config): logger.info(f"Applied {ac_config.ac_mode} activation checkpointing to the model") -def apply_compile(model: nn.Module): +def apply_compile(model: nn.Module, fullgraph: bool = True): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). """ for layer_id, transformer_block in model.layers.named_children(): - transformer_block = torch.compile(transformer_block, fullgraph=True) + transformer_block = torch.compile(transformer_block, fullgraph=fullgraph) model.layers.register_module(layer_id, transformer_block) - logger.info("Compiling each TransformerBlock with torch.compile") + logger.info(f"Compiling each TransformerBlock with torch.compile (fullgraph={fullgraph})") def apply_fsdp( diff --git a/train.py b/train.py index 1a8ce09..5a56f6f 100644 --- a/train.py +++ b/train.py @@ -19,6 +19,8 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel +from torch.distributed.tensor.experimental import context_parallel +from torch.distributed.tensor.experimental._attention import set_rotate_method from transformers import AutoTokenizer, PreTrainedTokenizerFast @@ -98,6 +100,7 @@ def main(): dp_shard=cfg.data_parallel_shard_degree, dp_replicate=cfg.data_parallel_replicate_degree, tp=cfg.tensor_parallel_degree, + cp=cfg.context_parallel_degree, ep=cfg.expert_parallel_degree, world_size=world_size, enable_loss_parallel=cfg.enable_loss_parallel, @@ -126,6 +129,17 @@ def main(): if parallel_dims.dp_enabled: logger.info(f"{dp_mesh=}") + # TODO: Is this still needed? `_set_cp_global_var` is only available in torch 2.9.0 + # if parallel_dims.cp_enabled: # the following is necessary for CP w/ flex attention + # from torch.distributed.tensor.experimental._attention import _set_cp_global_var, _DispatchMode, _cp_options + + # # set_rotate_method("alltoall") # alltoall or allgather (only allgather for flex) + # _set_cp_global_var("cp_shard_dim", 2) + # # _cp_options.enable_load_balance = True # no load balancing for flex + # torch.distributed.tensor.experimental._attention._dispatch_mode = ( + # _DispatchMode.TORCH_FUNCTION + # ) + # Get tokenizer to determine vocab size if os.path.isfile(cfg.tokenizer_name): tokenizer = PreTrainedTokenizerFast(tokenizer_file=cfg.tokenizer_name) @@ -158,7 +172,7 @@ def main(): logger.info( f"Building {cfg.model_name} {cfg.flavor} with {model_config}" ) - model = model_cls.from_model_args(model_config) + model = model_cls.from_model_args(model_config, cp_device_mesh=world_mesh["cp"] if parallel_dims.cp_enabled else None) # log model size # model_param_count = get_num_params(model) @@ -331,19 +345,34 @@ def opt_step(): # Get document_ids if available (for flex attention document masking in packed data) document_ids = batch.get("document_ids", None) - # Collect padding stats if available (SFT mode) - if "stats" in batch and "actual_lengths" in batch["stats"]: - padding_lengths_since_last_log.append(batch["stats"]["actual_lengths"]) - - ntokens_since_last_log += labels.numel() - data_loading_times.append(timer() - data_load_start) - input_ids = input_ids.cuda() labels = labels.cuda() if position_ids is not None: position_ids = position_ids.cuda() if document_ids is not None: document_ids = document_ids.cuda() + + buffers = [input_ids, labels] + buffer_seq_dims = [1, 1] # shard on seq dim + if hasattr(model, 'freqs_cis'): + buffers.extend([model.freqs_cis]) + buffer_seq_dims.extend([0]) + elif hasattr(model, 'local_freqs_cis') and hasattr(model, 'global_freqs_cis'): + buffers.extend([model.local_freqs_cis, model.global_freqs_cis]) + buffer_seq_dims.extend([0, 0]) + context_parallel_ctx = context_parallel( + world_mesh["cp"], + buffers=buffers, + buffer_seq_dims=buffer_seq_dims, + no_restore_buffers={input_ids, labels}, # don't restore + ) if parallel_dims.cp_enabled else contextlib.nullcontext() + + # Collect padding stats if available (SFT mode) + if "stats" in batch and "actual_lengths" in batch["stats"]: + padding_lengths_since_last_log.append(batch["stats"]["actual_lengths"]) + + ntokens_since_last_log += labels.numel() + data_loading_times.append(timer() - data_load_start) sync_grads_now = True if skip_sync_during_accum: @@ -352,7 +381,7 @@ def opt_step(): if fsdp_can_toggle_sync and grad_accum_steps > 1: model.set_requires_gradient_sync(sync_grads_now) - with loss_parallel_ctx(): + with loss_parallel_ctx(), context_parallel_ctx: if cfg.enable_cut_cross_entropy: loss = model( input_ids, @@ -428,7 +457,7 @@ def opt_step(): time_delta = timer() - time_last_log total_tokens += ntokens_since_last_log - tps = ntokens_since_last_log / (time_delta * parallel_dims.model_parallel_size) + tps = ntokens_since_last_log / (time_delta * parallel_dims.non_data_parallel_size) mfu = 100 * num_flop_per_token * tps / gpu_peak_flops time_end_to_end = time_delta / cfg.log_freq From 9d50ff902e02d31a0d611b5d77f13244006e8dd1 Mon Sep 17 00:00:00 2001 From: Oliver Kinch Date: Thu, 13 Nov 2025 10:51:51 +0100 Subject: [PATCH 2/4] Unique name --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 5a56f6f..81a8002 100644 --- a/train.py +++ b/train.py @@ -466,7 +466,7 @@ def opt_step(): # Aggregate data loading times across ALL ranks (TP ranks load redundantly) # Flatten world mesh to get all ranks - global_mesh = world_mesh._flatten() if hasattr(world_mesh, '_flatten') else world_mesh + global_mesh = world_mesh._flatten(mesh_dim_name="global") if hasattr(world_mesh, '_flatten') else world_mesh global_avg_data_loading = dist_mean(time_data_loading, global_mesh).item() global_max_data_loading = dist_max(time_data_loading, global_mesh).item() global_avg_data_loading_pct = dist_mean(time_data_loading_pct, global_mesh).item() From 6485e46b5950423720237b77c557fe85c132680a Mon Sep 17 00:00:00 2001 From: Oliver Kinch Date: Thu, 13 Nov 2025 10:52:24 +0100 Subject: [PATCH 3/4] Compatibility with train current train code --- maester/models/deepseek/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maester/models/deepseek/model.py b/maester/models/deepseek/model.py index aabb247..ee9f2fb 100644 --- a/maester/models/deepseek/model.py +++ b/maester/models/deepseek/model.py @@ -395,6 +395,6 @@ def forward( return output @classmethod - def from_model_args(cls, model_args: DeepSeekModelArgs) -> "DeepSeekModel": + def from_model_args(cls, model_args: DeepSeekModelArgs, cp_device_mesh=None) -> "DeepSeekModel": """Initialize from model args (compatible with training loop).""" return cls(model_args) \ No newline at end of file From d36078d7f6c5efabfafa833404bc98d97f5f4908 Mon Sep 17 00:00:00 2001 From: Oliver Kinch Date: Wed, 10 Dec 2025 14:11:31 +0100 Subject: [PATCH 4/4] YaRN implementation New dcp script related to model where yarn has been used to extend the context length --- maester/models/llama/__init__.py | 9 + maester/models/llama/model.py | 154 +++++++- scripts/convert/llama/from_dcp_yarn.py | 317 ++++++++++++++++ scripts/convert/llama/hf_maester_llama.py | 431 ++++++++++++++++++++++ 4 files changed, 898 insertions(+), 13 deletions(-) create mode 100644 scripts/convert/llama/from_dcp_yarn.py create mode 100644 scripts/convert/llama/hf_maester_llama.py diff --git a/maester/models/llama/__init__.py b/maester/models/llama/__init__.py index f7dd965..e73d61a 100644 --- a/maester/models/llama/__init__.py +++ b/maester/models/llama/__init__.py @@ -94,6 +94,15 @@ max_seq_len=4096, vocab_size=64256, ), + "Comma7B-32k": ModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + rope_theta=100000.0, + max_seq_len=32768, + vocab_size=64256, + original_max_context_length=4096, + ), "8B": ModelArgs( dim=4096, n_layers=32, diff --git a/maester/models/llama/model.py b/maester/models/llama/model.py index acca71f..abc2ea9 100644 --- a/maester/models/llama/model.py +++ b/maester/models/llama/model.py @@ -49,6 +49,11 @@ class ModelArgs: mup_output_alpha: float = 1.0 mup_width_mul: float = 1.0 # = width / base_width + # YARN (Yet Another RoPE extensioN) context extension: + # set ``original_max_context_length`` to the model's *old* context length + # when increasing ``max_seq_len`` so YARN RoPE scaling can be applied. + original_max_context_length: Optional[int] = None + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: """ Calculate the number of parameters and FLOPS per token. @@ -85,29 +90,147 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in return nparams, num_flops_per_token -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: +def precompute_freqs_cis( + dim: int, + max_context_length: int, + theta: float = 10000.0, + device: str = "cuda", + original_max_context_length: Optional[int] = None, + beta_fast: float = 32.0, + beta_slow: float = 1.0, +) -> torch.Tensor: """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + Supports YaRN (Yet another RoPE extensioN) scaling for context window extension, + following the implementation pattern used in torchtitan / DeepSeek-V3. This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + dim (int): Dimension of the frequency tensor (per-head hidden size). + max_context_length (int): Target maximum context length for inference. + theta (float, optional): RoPE base. Defaults to 10000.0. + device (str): Device to create tensors on. Defaults to "cuda". + original_max_context_length (Optional[int]): Original training + context length for YaRN. If None, YaRN is disabled and standard RoPE is used. + beta_fast (float): YaRN hyperparameter controlling the fast-rotating band. + beta_slow (float): YaRN hyperparameter controlling the slow-rotating band. Returns: torch.Tensor: Precomputed frequency tensor with complex exponentials. """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # End index for precomputing frequencies (typically a small safety + # margin above the maximum context length). + end = max_context_length * 2 + + # Basic RoPE frequency calculation (per torchtitan / DeepSeek-V3 style) + freqs = 1.0 / ( + theta + ** ( + torch.arange(0, dim, 2, device=device, dtype=torch.float32)[ + : (dim // 2) + ] + / dim + ) + ) + + # YaRN scaling for extended context. YaRN is used to extend the context length + # after pre-training. We derive the scaling factor from the ratio between the + # target max context and the original training context. + if ( + original_max_context_length is not None + and max_context_length > original_max_context_length + ): + # seqlen here corresponds to the *target* context window (before the 2x safety margin) + seqlen = max_context_length + base = theta + + # How much we are extending the context window compared to training. + factor = float(seqlen) / float(original_max_context_length) + + # Compute the band of dimensions where we apply the smooth correction, + # using the same helpers as the torchtitan implementation. + low, high = _find_correction_range( + beta_fast, + beta_slow, + dim, + base, + original_max_context_length, + ) + smooth = 1.0 - _linear_ramp_factor(low, high, dim // 2, device) + + # Blend between the down-scaled and original frequencies. + # Outside the [low, high] band, we mostly use the scaled version; inside the + # band, we gradually recover the original frequencies. + freqs = freqs / factor * (1.0 - smooth) + freqs * smooth + + # Positions we will precompute for (may be larger than the actual max context + # to give some safety margin). + t = torch.arange(end, device=device, dtype=torch.float32) + + if ( + original_max_context_length is not None + and max_context_length > original_max_context_length + ): + # Compress the target context range [0, max_seq_len) + # into the original range [0, original_max_context_length). + # Note: we intentionally base the scale factor on the *target* context, + # not on `end`, so that changing the safety margin does not change the + # effective RoPE scaling. + scale_factor = ( + original_max_context_length + / float(max_context_length) + ) + t = t * scale_factor + + freqs_scaled = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs_scaled), freqs_scaled) return freqs_cis +def _find_correction_dim( + num_rotations: float, dim: int, base: float, max_seq_len: int +) -> float: + """ + Compute the correction dimension for a given number of rotations + in the rotary positional embedding (YaRN helper). + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + +def _find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int +) -> Tuple[int, int]: + """ + Compute the range of correction dimensions for rotary positional embeddings. + Mirrors torchtitan's YaRN implementation. + """ + low = math.floor(_find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(_find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + +def _linear_ramp_factor( + min_val: float, max_val: float, dim: int, device: str +) -> torch.Tensor: + """ + Linear ramp function used to smoothly blend scaled and unscaled frequencies. + """ + if min_val == max_val: + max_val += 0.001 + linear_func = ( + torch.arange(dim, device=device, dtype=torch.float32) - min_val + ) / (max_val - min_val) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ Reshape frequency tensor for broadcasting it with another tensor. @@ -453,12 +576,17 @@ def init_weights(self): nn.init.normal_(self.output.weight, std=self.model_args.init_std) def _precompute_freqs_cis(self) -> torch.Tensor: + # We always precompute frequencies up to ``max_seq_len``. If + # ``original_max_context_length`` is set and + # ``max_seq_len`` exceeds it, YARN-style scaling is applied inside + # ``precompute_freqs_cis``. + max_context = self.model_args.max_seq_len + return precompute_freqs_cis( - self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # (use 2x max sequence length to be safe) - self.model_args.max_seq_len * 2, - self.model_args.rope_theta, + dim=self.model_args.dim // self.model_args.n_heads, + max_context_length=max_context, + theta=self.model_args.rope_theta, + original_max_context_length=self.model_args.original_max_context_length, ) def forward( diff --git a/scripts/convert/llama/from_dcp_yarn.py b/scripts/convert/llama/from_dcp_yarn.py new file mode 100644 index 0000000..1783613 --- /dev/null +++ b/scripts/convert/llama/from_dcp_yarn.py @@ -0,0 +1,317 @@ +"""Run in singularity container on a machine with enough RAM. For example: +python scripts/convert_dcp_to_hf.py /path/to/checkpoints/ /path/to/output/ \ + --upload danish-foundation-models/munin-7b-{expname} --name step-1000 --base mistralai/Mistral-7B-v0.1 +""" + +import argparse +import json +import os +import re +import shutil +from pathlib import Path + +from typing import Optional + +import torch +from torch.distributed.checkpoint.filesystem import FileSystemReader # type: ignore[attr-defined] +from torch.distributed.checkpoint.metadata import ( # type: ignore[attr-defined] + Metadata, + STATE_DICT_TYPE, + TensorStorageMetadata +) +from torch.distributed.checkpoint._traverse import set_element # type: ignore[attr-defined] +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner # type: ignore[attr-defined] +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict # type: ignore[attr-defined] +from transformers import AutoConfig, AutoTokenizer +import safetensors.torch + +try: + from maester.models import models_config +except ImportError: # pragma: no cover - fallback for standalone use + models_config = {} + + +def _find_job_config(start_path: str): + """Search upwards from the checkpoint directory for a job config.""" + path = Path(start_path).resolve() + for current in (path, *path.parents): + candidate = current / "config.json" + if candidate.is_file(): + try: + with candidate.open("r", encoding="utf-8") as handle: + return json.load(handle), candidate + except json.JSONDecodeError: + print(f"Warning: failed to parse {candidate}, ignoring.") + return None, None + + +def _load_model_args(job_config: Optional[dict]): + if not job_config: + return None + model_name = job_config.get("model_name") + flavor = job_config.get("flavor") + if not model_name or not flavor: + return None + config_store = models_config.get(model_name) + if not config_store: + return None + return config_store.get(flavor) + + +def _resolve_export_dtype(job_config: Optional[dict]) -> torch.dtype: + dtype = torch.bfloat16 + if not job_config: + return dtype + candidate = job_config.get("export_dtype") or job_config.get("mixed_precision_param") + if not isinstance(candidate, str): + return dtype + mapping = { + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float16": torch.float16, + "fp16": torch.float16, + "half": torch.float16, + "float32": torch.float32, + "fp32": torch.float32, + } + return mapping.get(candidate.lower(), dtype) + + +def _infer_num_layers(state_dict: STATE_DICT_TYPE): + layer_ids = [] + for key in state_dict: + if "layers" not in key: + continue + match = re.search(r"layers.(\d+)", key) + if match: + layer_ids.append(int(match.group(1))) + if not layer_ids: + return None + return max(layer_ids) + 1 + + +def _infer_head_counts(state_dict: STATE_DICT_TYPE, hf_config, model_args, hidden_size: int): + num_heads = None + if model_args is not None: + num_heads = getattr(model_args, "n_heads", None) + if num_heads is None: + num_heads = getattr(hf_config, "num_attention_heads", None) + if num_heads is None: + raise ValueError("Unable to determine number of attention heads; provide --base or ensure config.json is available.") + + kv_heads = None + wk_weight = state_dict.get('layers.0.attention.wk.weight') + if isinstance(wk_weight, torch.Tensor) and hidden_size % num_heads == 0: + head_dim = hidden_size // num_heads + if head_dim and wk_weight.shape[0] % head_dim == 0: + kv_heads = wk_weight.shape[0] // head_dim + + if not kv_heads: + if model_args is not None: + kv_heads = getattr(model_args, "n_kv_heads", None) + if not kv_heads: + kv_heads = getattr(hf_config, "num_key_value_heads", None) + if not kv_heads: + kv_heads = num_heads + + return num_heads, kv_heads + +class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata. + Useful for loading in state_dict without first initializing a model, such as + when converting a DCP checkpoint into a Torch save file. + + . N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner + + .. warning:: + Because the entire state dict is initialized, It's recommended to only utilize + this LoadPlanner on a single rank or process to avoid OOM. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def set_up_planner( # type: ignore[override] + self, + state_dict: STATE_DICT_TYPE, + metadata: Metadata | None = None, + is_coordinator: bool = True, + ) -> None: + assert not state_dict + + # rebuild the state dict from the metadata + assert metadata is not None + for k, v in metadata.state_dict_metadata.items(): + if isinstance(v, TensorStorageMetadata): + v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment] + if k in metadata.planner_data: + set_element(state_dict, metadata.planner_data[k], v) + else: + state_dict[k] = v + + super().set_up_planner(state_dict, metadata, is_coordinator) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("src", type=str, help="Path to the source DCP model") + parser.add_argument("dst", type=str, help="Path to the destination model") + parser.add_argument("--base", type=str, required=True, help="Path to HF model this is based on, also uses tokenizer unless --tokenizer is specified") # TODO: can we not do this?? + parser.add_argument("--tokenizer", type=str, default=None, help="Path to HF tokenizer this is based on") # TODO: can we not do this?? + parser.add_argument("--name", type=str, required=True, help="Name (variant) of the model checkpoint to load, e.g. step-1000") + parser.add_argument("--type", type=str, default="hf", choices=["hf", "pt"], help="Type of the destination model") + parser.add_argument("--upload", type=str, default=None, help="HF repo to upload to (name gets appended)") + args = parser.parse_args() + + src_dir = os.path.join(args.src, args.name) + dst_dir = os.path.join(args.dst, args.name) + if not os.path.isdir(src_dir): + raise RuntimeError(f"Source DCP {src_dir} does not exist") + sd: STATE_DICT_TYPE = {} + storage_reader = FileSystemReader(src_dir) + + print('Loading checkpoint...') + _load_state_dict( + sd, + storage_reader=storage_reader, + planner=_EmptyStateDictLoadPlanner(), + no_dist=True, + ) + if 'model' in sd: # model-only checkpoints do not have this + print(f"Full checkpoint detected, extracting model weights only. All keys: {list(sd.keys())}") + sd = sd['model'] + sd = {k.replace('._orig_mod', ''): v for k, v in sd.items()} # fix '_orig_mod' thing... + print(f"Model keys: {list(sd.keys())}") + + job_config, job_config_path = _find_job_config(src_dir) + if job_config_path: + print(f"Detected job config at {job_config_path}") + model_args = _load_model_args(job_config) + if model_args and job_config: + print( + "Using model metadata: " + f"{job_config.get('model_name')}/{job_config.get('flavor')}" + ) + + if args.type == "hf": + # Build and save HF Config + print('#' * 30) + print('Saving HF Model Config...') + hf_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer if args.tokenizer else args.base) + dtype = _resolve_export_dtype(job_config) + hf_config = AutoConfig.from_pretrained(args.base) + hf_config.torch_dtype = dtype + + # ------------------------------------------------------------------ + # Align RoPE / max-context settings with Maester's ModelArgs/YARN + # ------------------------------------------------------------------ + if model_args is not None: + # Match RoPE base frequency if present + rope_theta = getattr(model_args, "rope_theta", None) + if rope_theta is not None: + # Many HF llama-style configs expose this directly + setattr(hf_config, "rope_theta", rope_theta) + + # Match maximum context length + max_seq_len = getattr(model_args, "max_seq_len", None) + if max_seq_len is not None: + # Standard HF field for context window + setattr(hf_config, "max_position_embeddings", max_seq_len) + + # If YARN-style extension is configured in Maester, mirror it via rope_scaling + original_ctx = getattr(model_args, "original_max_context_length", None) + if original_ctx is not None and max_seq_len is not None and max_seq_len > original_ctx: + factor = float(max_seq_len) / float(original_ctx) + # Use a generic rope_scaling dict that modern HF LLaMA implementations understand. + # The exact semantics (especially for type="yarn") are delegated to the modeling code. + rope_scaling = { + "type": "yarn", + "original_max_position_embeddings": int(original_ctx), + "factor": factor, + } + setattr(hf_config, "rope_scaling", rope_scaling) + + inferred_layers = _infer_num_layers(sd) + if inferred_layers is not None: + hf_config.num_hidden_layers = inferred_layers + hidden_size = sd['layers.0.attention.wq.weight'].shape[0] + hf_config.hidden_size = hidden_size + num_heads, kv_heads = _infer_head_counts(sd, hf_config, model_args, hidden_size) + hf_config.num_attention_heads = num_heads + hf_config.num_key_value_heads = kv_heads + hf_config.intermediate_size = sd['layers.0.feed_forward.w1.weight'].shape[0] + hf_config.vocab_size = sd['tok_embeddings.weight'].shape[0] + if hf_tokenizer is not None: + if hf_tokenizer.bos_token_id is not None: + hf_config.bos_token_id = hf_tokenizer.bos_token_id + if hf_tokenizer.eos_token_id is not None: + hf_config.eos_token_id = hf_tokenizer.eos_token_id + # Use custom Maester LLaMA wrapper in HF (trust_remote_code) + hf_config.auto_map = { + "AutoModelForCausalLM": "hf_maester_llama.MaesterLlamaForCausalLM" + } + hf_config.save_pretrained(dst_dir) + print(hf_config) + + # Extract and save the HF tokenizer + print('#' * 30) + print('Saving HF Tokenizer...') + if hf_tokenizer is not None: + hf_tokenizer.save_pretrained(dst_dir) + print(hf_tokenizer) + else: + print('Warning! No HF Tokenizer found!') + + # Copy custom modeling file into export directory + src_modeling = Path(__file__).parent / "hf_maester_llama.py" + if src_modeling.is_file(): + shutil.copy(src_modeling, Path(dst_dir) / "hf_maester_llama.py") + else: + print(f"Warning: {src_modeling} not found; HF model will not load without it.") + + # Extract the HF model weights + print('#' * 30) + print('Saving HF Model Weights...') + # Convert weights to desired dtype and prefix with "model." to match + # MaesterLlamaForCausalLM.base_model_prefix. + final_state: dict[str, torch.Tensor] = {} + for k, v in sd.items(): + if isinstance(v, torch.Tensor): + v = v.to(dtype=dtype) + final_state[f"model.{k}"] = v + + safetensors.torch.save_file( + final_state, + os.path.join(dst_dir, 'model.safetensors'), + metadata={"format": "pt"}, + ) + + print('#' * 30) + print(f'HF checkpoint folder successfully created at {dst_dir}.') + + if args.upload: + from huggingface_hub import HfApi + api = HfApi() + repo_id = f"{args.upload}-{args.name}" + + print( + f'Uploading {dst_dir} to HuggingFace Hub at {repo_id}' + ) + api.create_repo(repo_id=repo_id, + use_auth_token=True, + repo_type='model', + private=True, + exist_ok=False) + print('Repo created.') + + api.upload_folder(folder_path=dst_dir, + repo_id=repo_id, + use_auth_token=True, + repo_type='model', + ) + print('Folder uploaded.') + elif args.type == "pt": + torch.save(sd, dst_dir) + else: + raise ValueError(f"Unknown destination type {args.type}") diff --git a/scripts/convert/llama/hf_maester_llama.py b/scripts/convert/llama/hf_maester_llama.py new file mode 100644 index 0000000..5323d80 --- /dev/null +++ b/scripts/convert/llama/hf_maester_llama.py @@ -0,0 +1,431 @@ +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from transformers import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.modeling_outputs import CausalLMOutputWithPast + + +# --------------------------------------------------------------------------- +# Minimal copy of Maester's LLaMA architecture with YARN RoPE, simplified +# for inference-only use in Hugging Face. +# +# This is derived from maester.models.llama.model. It always returns logits; +# loss is computed in the HF wrapper. +# --------------------------------------------------------------------------- + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 + multiple_of: int = 256 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + init_std: float = 0.02 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + norm_type: str = "rmsnorm" + + # YARN args + original_max_context_length: Optional[int] = None + + +def precompute_freqs_cis( + dim: int, + max_context_length: int, + theta: float = 10000.0, + device: str = "cuda", + original_max_context_length: Optional[int] = None, + beta_fast: float = 32.0, + beta_slow: float = 1.0, +) -> torch.Tensor: + end = max_context_length * 2 + + freqs = 1.0 / ( + theta + ** ( + torch.arange(0, dim, 2, device=device, dtype=torch.float32)[: dim // 2] + / dim + ) + ) + + if ( + original_max_context_length is not None + and max_context_length > original_max_context_length + ): + seqlen = max_context_length + base = theta + factor = float(seqlen) / float(original_max_context_length) + + low, high = _find_correction_range( + beta_fast, + beta_slow, + dim, + base, + original_max_context_length, + ) + smooth = 1.0 - _linear_ramp_factor(low, high, dim // 2, device) + + freqs = freqs / factor * (1.0 - smooth) + freqs * smooth + + t = torch.arange(end, device=device, dtype=torch.float32) + + if ( + original_max_context_length is not None + and max_context_length > original_max_context_length + ): + scale_factor = original_max_context_length / float(max_context_length) + t = t * scale_factor + + freqs_scaled = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs_scaled), freqs_scaled) + return freqs_cis + + +def _find_correction_dim(num_rotations: float, dim: int, base: float, max_seq_len: int) -> float: + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + +def _find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int +) -> Tuple[int, int]: + low = math.floor(_find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(_find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + +def _linear_ramp_factor( + min_val: float, max_val: float, dim: int, device: str +) -> torch.Tensor: + if min_val == max_val: + max_val += 0.001 + linear_func = ( + torch.arange(dim, device=device, dtype=torch.float32) - min_val + ) / (max_val - min_val) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if freqs_cis.ndim != xq_.ndim: + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + norm_x = x.norm(2, dim=-1, keepdim=True) + return x / (norm_x / math.sqrt(x.shape[-1]) + self.eps) * self.weight + + +def create_norm(norm_type: str, dim: int, eps: float) -> nn.Module: + # For our purposes, only RMSNorm is needed. + return RMSNorm(dim=dim, eps=eps) + + +class Attention(nn.Module): + def __init__(self, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False) + + self.attn_scale = 1.0 / math.sqrt(self.head_dim) + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + keys = repeat_kv(xk, self.n_rep) + values = repeat_kv(xv, self.n_rep) + + xq = xq.transpose(1, 2) + xk = keys.transpose(1, 2) + xv = values.transpose(1, 2) + + output = F.scaled_dot_product_attention( + xq, xk, xv, is_causal=True, enable_gqa=True, scale=self.attn_scale + ) + output = output.transpose(1, 2).contiguous() + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = create_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + self.ffn_norm = create_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, model_args: ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=False) + + self.layers = nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = create_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + + self._init_weights() + + # Needed so Hugging Face's loader can query parameter dtypes on this inner module + def get_parameter_or_buffer(self, name: str): + """ + Minimal equivalent of PreTrainedModel.get_parameter_or_buffer, so that + transformers.modeling_utils._infer_parameter_dtype can work when it + recurses into the inner `model` module. + """ + module: nn.Module = self + if "." in name: + parts = name.split(".") + for p in parts[:-1]: + module = getattr(module, p) + name = parts[-1] + if name in module._parameters: + return module._parameters[name] + if name in module._buffers: + return module._buffers[name] + raise AttributeError(f"No parameter or buffer named {name}") + + def _init_weights(self) -> None: + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + nn.init.normal_(self.tok_embeddings.weight, std=self.model_args.init_std) + for layer in self.layers.values(): + # Norms are already initialized; attention/ffn use default init + pass + nn.init.normal_(self.output.weight, std=self.model_args.init_std) + + def _precompute_freqs_cis(self) -> torch.Tensor: + max_context = self.model_args.max_seq_len + return precompute_freqs_cis( + dim=self.model_args.dim // self.model_args.n_heads, + max_context_length=max_context, + theta=self.model_args.rope_theta, + original_max_context_length=self.model_args.original_max_context_length, + ) + + def forward( + self, + tokens: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + h = self.tok_embeddings(tokens) + batch_size, seq_len = tokens.shape + + if position_ids is not None: + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + assert position_ids.shape[0] == batch_size and position_ids.shape[1] == seq_len + position_ids = position_ids.long().to(device=self.freqs_cis.device) + freqs_cis = self.freqs_cis[position_ids] + freqs_cis = freqs_cis.unsqueeze(2) + else: + freqs_cis = self.freqs_cis + + for layer in self.layers.values(): + h = layer(h, freqs_cis) + h = self.norm(h) + output = self.output(h) + return output + + +# --------------------------------------------------------------------------- +# Hugging Face wrapper +# --------------------------------------------------------------------------- + + +class MaesterLlamaForCausalLM(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + + def __init__(self, config: LlamaConfig): + super().__init__(config) + + dim = config.hidden_size + n_layers = config.num_hidden_layers + n_heads = config.num_attention_heads + n_kv_heads = getattr(config, "num_key_value_heads", n_heads) + vocab_size = config.vocab_size + rope_theta = getattr(config, "rope_theta", 10000.0) + max_seq_len = getattr(config, "max_position_embeddings", 2048) + init_std = getattr(config, "initializer_range", 0.02) + + original_ctx = None + rope_scaling = getattr(config, "rope_scaling", None) + if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling: + original_ctx = int(rope_scaling["original_max_position_embeddings"]) + + model_args = ModelArgs( + dim=dim, + n_layers=n_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + vocab_size=vocab_size, + rope_theta=rope_theta, + max_seq_len=max_seq_len, + original_max_context_length=original_ctx, + init_std=init_std, + ) + + self.model = Transformer(model_args) + self.lm_head = self.model.output + + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + # type: ignore[assignment] + self.model.tok_embeddings = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + logits = self.model(tokens=input_ids, position_ids=position_ids) # type: ignore[arg-type] + + loss = None + if labels is not None: + shift_logits = logits[:, :-1, :].contiguous() + shift_labels = labels[:, 1:].contiguous() + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ignore_index=-100, + ) + + return CausalLMOutputWithPast( + loss=loss, # type: ignore[arg-type] + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + +