1 change: 1 addition & 0 deletions maester/config.py
@@ -82,6 +82,7 @@ class Config(BaseSettings):
data_parallel_shard_degree: int = 8
data_parallel_replicate_degree: int = 1
tensor_parallel_degree: int = 1
context_parallel_degree: int = 1
expert_parallel_degree: int = 1
train_batch_size: int = 2 # per device; 2 * 8 gpus * 32 nodes * 8192 seqlen = ~4M tokens per batch
gradient_accumulation_steps: int = 1
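The new degree composes with the existing ones. As a quick illustration (a sketch of the usual torchtitan-style constraint, assumed here rather than shown in this hunk), the product of the degrees must match the number of ranks launched:

# Illustrative only, not code from this PR: the parallelism degrees are
# assumed to factor the world size, as in torchtitan-style trainers.
dp_shard, dp_replicate, tp, cp = 8, 1, 1, 2
world_size = dp_shard * dp_replicate * tp * cp
assert world_size == 16  # launch with this many ranks
# With cp = 2, each rank attends over its shard of every sequence,
# halving per-rank activation memory at a fixed sequence length.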
2 changes: 1 addition & 1 deletion maester/models/deepseek/model.py
@@ -395,6 +395,6 @@ def forward(
return output

@classmethod
- def from_model_args(cls, model_args: DeepSeekModelArgs) -> "DeepSeekModel":
def from_model_args(cls, model_args: DeepSeekModelArgs, cp_device_mesh=None) -> "DeepSeekModel":
"""Initialize from model args (compatible with training loop)."""
return cls(model_args)
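Note that this variant (like the GLM4 one below) accepts the new cp_device_mesh argument but does not consume it; presumably the signature change only keeps from_model_args uniform across architectures so the training loop can pass the mesh unconditionally.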
25 changes: 15 additions & 10 deletions maester/models/gemma/model.py
@@ -6,6 +6,7 @@
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import flex_attention as _flex_attention
from torch.distributed import DeviceMesh

from maester.log_utils import logger

@@ -285,7 +286,8 @@ class GemmaAttention(nn.Module):
def __init__(
self,
config: ModelArgs,
- attn_type: str
attn_type: str,
cp_device_mesh: DeviceMesh | None
):
super().__init__()

@@ -448,13 +450,15 @@ class Gemma2DecoderLayer(nn.Module):
def __init__(
self,
config: ModelArgs,
- attn_type: str
attn_type: str,
cp_device_mesh: DeviceMesh | None
):
super().__init__()
self.attn_type = attn_type
self.self_attn = GemmaAttention(
config=config,
- attn_type=attn_type
attn_type=attn_type,
cp_device_mesh=cp_device_mesh
)
self.mlp = GemmaMLP(
hidden_size=config.dim,
@@ -523,7 +527,7 @@ def init_weights(self, init_std: float):
self.mlp.init_weights(init_std)

class GemmaModel(nn.Module):
- def __init__(self, config: ModelArgs):
def __init__(self, config: ModelArgs, cp_device_mesh: DeviceMesh | None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
@@ -535,7 +539,7 @@ def __init__(self, config: ModelArgs):
if config.attn_types is not None
else "global"
)
- self.layers.append(Gemma2DecoderLayer(config, attn_type))
self.layers.append(Gemma2DecoderLayer(config, attn_type, cp_device_mesh))
self.norm = RMSNorm(config.dim, eps=config.rms_norm_eps)

def forward(
@@ -569,21 +573,22 @@ def init_weights(self, init_std: float):

class GemmaTextModel(nn.Module):
"""Text-only Gemma model compatible with training setup."""
- def __init__(self, config: ModelArgs):
def __init__(self, config: ModelArgs, cp_device_mesh: DeviceMesh | None = None):
super().__init__()
self.config = config
self.model_args = config # For compatibility with training code
self.vocab_size = config.vocab_size
self.n_layers = config.n_layers


self.cp_device_mesh = cp_device_mesh
# Text embeddings
self.tok_embeddings = Embedding(
num_embeddings=config.vocab_size,
embedding_dim=config.dim
)

# Core transformer model
- self.model = GemmaModel(config)
self.model = GemmaModel(config, cp_device_mesh=cp_device_mesh)

# Precompute RoPE frequencies following multimodal pattern
head_dim = config.head_dim
@@ -772,9 +777,9 @@ def forward(
return output

@classmethod
- def from_model_args(cls, model_args: ModelArgs) -> "GemmaTextModel":
def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: DeviceMesh | None = None) -> "GemmaTextModel":
"""Initialize from model args (compatible with training loop)."""
- return cls(model_args)
return cls(model_args, cp_device_mesh=cp_device_mesh)


class Gemma3MultiModalModel(nn.Module):
2 changes: 1 addition & 1 deletion maester/models/glm4/model.py
@@ -604,7 +604,7 @@ def _process_hidden_states(
return hidden_states

@classmethod
- def from_model_args(cls, model_args: ModelArgs) -> "Glm4MoeTextModel":
def from_model_args(cls, model_args: ModelArgs, cp_device_mesh=None) -> "Glm4MoeTextModel":
"""Initialize from model args (compatible with training loop)."""
return cls(model_args)

9 changes: 9 additions & 0 deletions maester/models/llama/__init__.py
@@ -94,6 +94,15 @@
max_seq_len=4096,
vocab_size=64256,
),
"Comma7B-32k": ModelArgs(
dim=4096,
n_layers=32,
n_heads=32,
rope_theta=100000.0,
max_seq_len=32768,
vocab_size=64256,
original_max_context_length=4096,
),
"8B": ModelArgs(
dim=4096,
n_layers=32,
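The new Comma7B-32k preset is the consumer of the YaRN support added to maester/models/llama/model.py below: max_seq_len is raised from the original 4096 to 32768, and original_max_context_length records the pre-training length so the RoPE scaling activates. A quick check of the implied numbers (illustrative arithmetic, not code from this PR):

# Comma7B-32k: 8x context extension, positions compressed to 1/8.
orig_len, new_len = 4096, 32768
assert new_len / orig_len == 8.0      # YaRN extension factor
assert orig_len / new_len == 0.125    # position scale applied to t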
158 changes: 144 additions & 14 deletions maester/models/llama/model.py
@@ -15,6 +15,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed.device_mesh import DeviceMesh

from maester.models.llama.tied_linear import TiedLinear
from maester.models.norms import create_norm
@@ -48,6 +49,11 @@ class ModelArgs:
mup_output_alpha: float = 1.0
mup_width_mul: float = 1.0 # = width / base_width

# YaRN (Yet another RoPE extensioN) context extension:
# set ``original_max_context_length`` to the model's *old* context length
# when increasing ``max_seq_len`` so YaRN RoPE scaling can be applied.
original_max_context_length: Optional[int] = None

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
"""
Calculate the number of parameters and FLOPS per token.
@@ -84,29 +90,147 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
return nparams, num_flops_per_token


- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
def precompute_freqs_cis(
dim: int,
max_context_length: int,
theta: float = 10000.0,
device: str = "cuda",
original_max_context_length: Optional[int] = None,
beta_fast: float = 32.0,
beta_slow: float = 1.0,
) -> torch.Tensor:
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
Supports YaRN (Yet another RoPE extensioN) scaling for context window extension,
following the implementation pattern used in torchtitan / DeepSeek-V3.

This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.

Args:
- dim (int): Dimension of the frequency tensor.
- end (int): End index for precomputing frequencies.
- theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
dim (int): Dimension of the frequency tensor (per-head hidden size).
max_context_length (int): Target maximum context length for inference.
theta (float, optional): RoPE base. Defaults to 10000.0.
device (str): Device to create tensors on. Defaults to "cuda".
original_max_context_length (Optional[int]): Original training
context length for YaRN. If None, YaRN is disabled and standard RoPE is used.
beta_fast (float): YaRN hyperparameter controlling the fast-rotating band.
beta_slow (float): YaRN hyperparameter controlling the slow-rotating band.

Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
- t = torch.arange(end, device=freqs.device)
- freqs = torch.outer(t, freqs).float()
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
# End index for precomputing frequencies: a 2x safety margin above the
# maximum context length.
end = max_context_length * 2

# Basic RoPE frequency calculation (per torchtitan / DeepSeek-V3 style)
freqs = 1.0 / (
theta
** (
torch.arange(0, dim, 2, device=device, dtype=torch.float32)[
: (dim // 2)
]
/ dim
)
)

# YaRN scaling for extending the context window after pre-training.
# The scaling factor is derived from the ratio between the target max
# context and the original training context.
if (
original_max_context_length is not None
and max_context_length > original_max_context_length
):
# seqlen here corresponds to the *target* context window (before the 2x safety margin)
seqlen = max_context_length
base = theta

# How much we are extending the context window compared to training.
factor = float(seqlen) / float(original_max_context_length)

# Compute the band of dimensions where we apply the smooth correction,
# using the same helpers as the torchtitan implementation.
low, high = _find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_context_length,
)
smooth = 1.0 - _linear_ramp_factor(low, high, dim // 2, device)

# Blend between the down-scaled and original frequencies.
# Outside the [low, high] band, we mostly use the scaled version; inside the
# band, we gradually recover the original frequencies.
freqs = freqs / factor * (1.0 - smooth) + freqs * smooth

# Positions we will precompute for (may be larger than the actual max context
# to give some safety margin).
t = torch.arange(end, device=device, dtype=torch.float32)

if (
original_max_context_length is not None
and max_context_length > original_max_context_length
):
# Compress the target context range [0, max_context_length)
# into the original range [0, original_max_context_length).
# Note: we intentionally base the scale factor on the *target* context,
# not on `end`, so that changing the safety margin does not change the
# effective RoPE scaling.
scale_factor = (
original_max_context_length
/ float(max_context_length)
)
t = t * scale_factor

freqs_scaled = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs_scaled), freqs_scaled)
return freqs_cis


def _find_correction_dim(
num_rotations: float, dim: int, base: float, max_seq_len: int
) -> float:
"""
Compute the correction dimension for a given number of rotations
in the rotary positional embedding (YaRN helper).
"""
return (
dim
* math.log(max_seq_len / (num_rotations * 2 * math.pi))
/ (2 * math.log(base))
)


def _find_correction_range(
low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int
) -> tuple[int, int]:
"""
Compute the range of correction dimensions for rotary positional embeddings.
Mirrors torchtitan's YaRN implementation.
"""
low = math.floor(_find_correction_dim(low_rot, dim, base, max_seq_len))
high = math.ceil(_find_correction_dim(high_rot, dim, base, max_seq_len))
return max(low, 0), min(high, dim - 1)


def _linear_ramp_factor(
min_val: float, max_val: float, dim: int, device: str
) -> torch.Tensor:
"""
Linear ramp function used to smoothly blend scaled and unscaled frequencies.
"""
if min_val == max_val:
max_val += 0.001
linear_func = (
torch.arange(dim, device=device, dtype=torch.float32) - min_val
) / (max_val - min_val)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
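To make the three helpers above concrete, here is a standalone sketch (not part of the diff; the per-head dim and rope_theta values are borrowed from the Comma7B-32k preset for illustration) that computes the correction band and applies the same blending formula as precompute_freqs_cis:

import math
import torch

dim, base = 128, 100000.0        # per-head dim and rope_theta (illustrative)
orig_len, new_len = 4096, 32768
factor = new_len / orig_len      # 8x extension

def correction_dim(num_rotations: float) -> float:
    # same formula as _find_correction_dim, specialized to these constants
    return dim * math.log(orig_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

low = max(math.floor(correction_dim(32.0)), 0)       # beta_fast edge -> 16
high = min(math.ceil(correction_dim(1.0)), dim - 1)  # beta_slow edge -> 37

ramp = torch.clamp(
    (torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1
)
smooth = 1.0 - ramp
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
yarn = freqs / factor * (1.0 - smooth) + freqs * smooth

# Fast-rotating dims (index < low) keep their original frequency;
# slow-rotating dims (index >= high) are fully divided by the 8x factor.
assert torch.allclose(yarn[:low], freqs[:low])
assert torch.allclose(yarn[high:], freqs[high:] / factor)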


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Reshape frequency tensor for broadcasting it with another tensor.
Expand Down Expand Up @@ -452,12 +576,17 @@ def init_weights(self):
nn.init.normal_(self.output.weight, std=self.model_args.init_std)

def _precompute_freqs_cis(self) -> torch.Tensor:
# ``max_seq_len`` is the target context length; ``precompute_freqs_cis``
# adds a 2x safety margin internally. If ``original_max_context_length``
# is set and ``max_seq_len`` exceeds it, YaRN-style scaling is applied.
max_context = self.model_args.max_seq_len

return precompute_freqs_cis(
- self.model_args.dim // self.model_args.n_heads,
- # Need to compute until at least the max token limit for generation
- # (use 2x max sequence length to be safe)
- self.model_args.max_seq_len * 2,
- self.model_args.rope_theta,
dim=self.model_args.dim // self.model_args.n_heads,
max_context_length=max_context,
theta=self.model_args.rope_theta,
original_max_context_length=self.model_args.original_max_context_length,
)

def forward(
Expand Down Expand Up @@ -528,12 +657,13 @@ def forward(
return output

@classmethod
- def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: DeviceMesh | None = None) -> "Transformer":
"""
Initialize a Transformer model from a ModelArgs object.

Args:
model_args (ModelArgs): Model configuration arguments.
cp_device_mesh (Optional[DeviceMesh]): Device mesh for context parallelism.

Returns:
Transformer: Transformer model.
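Taken together, the new cp_device_mesh parameter lets the training loop hand each model a context-parallel mesh at construction time. A hedged usage sketch (the mesh shape, dim names, and process-group setup are assumptions for illustration, not shown in this PR):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from maester.models.llama.model import ModelArgs, Transformer

dist.init_process_group("nccl")  # e.g. torchrun with 16 ranks
mesh = init_device_mesh("cuda", (8, 2), mesh_dim_names=("dp", "cp"))

args = ModelArgs(dim=4096, n_layers=32, n_heads=32, rope_theta=100000.0,
                 max_seq_len=32768, vocab_size=64256,
                 original_max_context_length=4096)
model = Transformer.from_model_args(args, cp_device_mesh=mesh["cp"])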