From c9a0e453604bce1b46081191720f3a7cc0bd55bd Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Thu, 26 Feb 2026 23:38:34 -0500 Subject: [PATCH 1/2] Add Trinity model family (AfmoeForCausalLM) contrib implementation Unified NxDI implementation supporting all three Arcee AI Trinity sizes (Nano ~6B, Mini ~26B, Large ~250B) from a single modeling_trinity.py. Validated on SDK 2.27 (NxDI 0.7.15063, neuronx-cc 2.22.12471): - Nano: inf2.8xlarge (TP=1) and trn2.3xlarge (TP=2) - Mini: trn2.3xlarge (TP=4) - Large: trn2.48xlarge (TP=64) --- contrib/models/Trinity/README.md | 339 +++++ contrib/models/Trinity/src/__init__.py | 9 + .../models/Trinity/src/modeling_trinity.py | 1328 +++++++++++++++++ contrib/models/Trinity/test/__init__.py | 0 .../Trinity/test/integration/__init__.py | 0 .../Trinity/test/integration/test_model.py | 337 +++++ contrib/models/Trinity/test/unit/__init__.py | 0 7 files changed, 2013 insertions(+) create mode 100644 contrib/models/Trinity/README.md create mode 100644 contrib/models/Trinity/src/__init__.py create mode 100644 contrib/models/Trinity/src/modeling_trinity.py create mode 100644 contrib/models/Trinity/test/__init__.py create mode 100644 contrib/models/Trinity/test/integration/__init__.py create mode 100644 contrib/models/Trinity/test/integration/test_model.py create mode 100644 contrib/models/Trinity/test/unit/__init__.py diff --git a/contrib/models/Trinity/README.md b/contrib/models/Trinity/README.md new file mode 100644 index 00000000..6025a050 --- /dev/null +++ b/contrib/models/Trinity/README.md @@ -0,0 +1,339 @@ +# Contrib Model: Trinity + +NeuronX Distributed Inference implementation of the Trinity model family (AfmoeForCausalLM) from Arcee AI. A single unified implementation supports all three model sizes. 
+ +## Model Family + +| Model | HuggingFace ID | Total Params | Active Params | Instance | +|-------|----------------|-------------|---------------|----------| +| **Nano** | `arcee-ai/Trinity-Nano-Preview` | ~6B | ~1B | inf2.8xlarge / trn2.3xlarge | +| **Mini** | `arcee-ai/Trinity-Mini` | ~26B | ~4.5B | trn2.3xlarge (TP=4) | +| **Large** | `arcee-ai/Trinity-Large-Preview` | ~250B | ~15B | trn2.48xlarge (TP=64) | + +**License:** Apache 2.0 + +## Architecture Details + +| Feature | Nano | Mini | Large | +|---------|------|------|-------| +| Layers | 56 (2 dense + 54 MoE) | 32 (2 dense + 30 MoE) | 60 (6 dense + 54 MoE) | +| Hidden Size | 1024 | 2048 | 3072 | +| Attention Heads | 8 | 32 | 48 | +| KV Heads (GQA) | 2 | 4 | 8 | +| Head Dim | 128 | 128 | 128 | +| Experts per MoE layer | 128 | 128 | 256 | +| Active Experts (TopK) | 8 | 8 | 4 | +| Shared Experts | 1 | 1 | 1 | +| Dense Intermediate | 3072 | 6144 | 12288 | +| MoE Intermediate | 256 | 1024 | 3072 | +| Sliding Window | 2048 | 2048 | 4096 | +| Max Position Embeddings | 131,072 | 131,072 | 262,144 | +| Vocabulary | 200,192 | 200,192 | 200,192 | +| Routing | Sigmoid + normalize (scale baked into weights) | +| Activation | SiLU gated MLP (`glu_type="glu"`) | +| Position Encoding | RoPE (sliding attention layers only) | +| Normalization | RMSNorm (4 per layer) | + +### Unique Architecture Features + +- **Mixed Attention:** Alternating sliding window and full attention (every 4th layer) +- **Gated Attention:** Sigmoid gate applied to attention output before o_proj +- **QK Normalization:** Per-head RMSNorm on Q and K +- **muP Scaling:** Embedding output scaled by hidden_size^0.5 +- **Expert Bias:** Learned bias added to routing scores for expert selection +- **Conditional RoPE:** Rotary embeddings applied only to sliding attention layers + +## Validation Results + +**Validated:** 2026-02-26 +**SDK:** NxDI 0.7.15063, neuronx-cc 2.22.12471, torch-neuronx 2.9.0.2.11, transformers 4.56.2 + +All results below are from the 
**unified `modeling_trinity.py`** (this code). + +### Trinity-Nano on trn2.3xlarge (TP=2, LNC=2) + +| Metric | Result | +|--------|--------| +| Compilation Time | 5.1 min | +| Load Time | 2.2 min | +| Forward Pass Latency | ~0.50s | + +**First-token predictions:** + +| Prompt | Top-1 Token | Logit | Top-5 | +|--------|-------------|-------|-------| +| "Hello, how are you?" | I | 17.75 | I, Hello, How | +| "Explain quantum computing in simple terms." | Answer | 21.00 | Answer, Quantum, What | +| "Write a Python function that calculates the Fibonacci sequence." | The | 24.75 | The, Your, Additionally | + +**Generation (5 tokens):** +- "Hello, how are you?" -> "I am fine, thank" +- "Explain quantum computing in simple terms." -> "Answer: Quantum computing uses" + +### Trinity-Mini on trn2.3xlarge (TP=4, LNC=2) + +| Metric | Result | +|--------|--------| +| Compilation Time | 4.9 min | +| Load Time | 4.1 min (from pre-compiled) | +| Forward Pass Latency | ~0.37s | + +**First-token predictions:** + +| Prompt | Top-1 Token | Logit | Top-5 | +|--------|-------------|-------|-------| +| "Hello, how are you?" | I | 20.12 | I, This, My | +| "Explain quantum computing in simple terms." | What | 20.75 | What, How, Quantum | +| "Write a Python function that calculates the Fibonacci sequence." | The | 28.00 | The, Your, It | + +**Generation (5 tokens):** +- "Hello, how are you?" -> "I'm fine, thank" +- "Explain quantum computing in simple terms." -> "What are the key differences" + +### Trinity-Nano on inf2.8xlarge (TP=1, no LNC) + +| Metric | Result | +|--------|--------| +| Compilation Time | Reused from trn2.3xlarge | +| Load Time | 47.7s | +| Forward Pass Latency | ~0.73s | + +**Note:** inf2.xlarge (16GB system RAM) cannot run Nano -- OOM killed at 15.3GB RSS during weight loading. inf2.8xlarge (123GB system RAM) works with TP=1. NxDI auto-converts GQA to MHA when `TP=1` and `num_kv_heads=2`. 
+ +### Trinity-Large on trn2.48xlarge (TP=64, LNC=2) + +| Metric | Result | +|--------|--------| +| Compilation Time | 8.6 min | +| Load Time | 15.6 min | +| Forward Pass Latency | ~1.15s | + +**First-token predictions:** + +| Prompt | Top-1 Token | +|--------|-------------| +| "Hello, how are you?" | I | +| "Explain quantum computing in simple terms." | Quantum | +| "Write a Python function that calculates the Fibonacci sequence." | The | + +**Notes:** +- TP=32 is insufficient -- sharded weights consume ~23.5GB per logical NeuronCore, exceeding the ~24GB HBM per physical NC and leaving no room for scratchpad/KV cache. TP=64 (all 64 logical cores on trn2.48xlarge) is required. +- Model is ~516GB on disk (31 safetensors in bf16). Root EBS volume (600GB) is insufficient -- NVMe instance store is required for model storage (`/mnt/nvme/`). +- Set `TMPDIR`, `BASE_COMPILE_WORK_DIR`, and `NEURON_COMPILE_CACHE_URL` to NVMe paths to avoid filling root disk during compilation. + +## Usage + +### Trinity-Nano-Preview (~6B total, ~1B active) + +```python +import torch +from transformers import AutoTokenizer +from neuronx_distributed_inference.models.config import MoENeuronConfig + +from src.modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + +model_path = "/path/to/arcee-ai/Trinity-Nano-Preview/" +compiled_path = "/path/to/compiled-nano/" + +neuron_config = MoENeuronConfig( + tp_degree=2, # Nano is small enough for TP=2 + batch_size=1, + seq_len=2048, + torch_dtype=torch.bfloat16, +) + +config = TrinityInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config +) + +model = NeuronTrinityForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained( + model_path, padding_side="right", trust_remote_code=True +) +``` + +**Instance:** inf2.8xlarge (TP=1) or trn2.3xlarge (TP=2). Does NOT fit inf2.xlarge (16GB system RAM causes OOM). 
+ +### Trinity-Mini (~26B total, ~4.5B active) + +```python +import torch +from transformers import AutoTokenizer +from neuronx_distributed_inference.models.config import MoENeuronConfig + +from src.modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + +model_path = "/path/to/arcee-ai/Trinity-Mini/" +compiled_path = "/path/to/compiled-mini/" + +neuron_config = MoENeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=2048, + torch_dtype=torch.bfloat16, +) + +config = TrinityInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config +) + +model = NeuronTrinityForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained( + model_path, padding_side="right", trust_remote_code=True +) +``` + +**Instance:** trn2.3xlarge (TP=4). Does NOT fit inf2.8xlarge (~48GB bf16). + +### Trinity-Large-Preview (~250B total, ~15B active) + +```python +import torch +from transformers import AutoTokenizer +from neuronx_distributed_inference.models.config import MoENeuronConfig + +from src.modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + +model_path = "/path/to/arcee-ai/Trinity-Large-Preview/" +compiled_path = "/path/to/compiled-large/" + +neuron_config = MoENeuronConfig( + tp_degree=64, + batch_size=1, + seq_len=4096, + torch_dtype=torch.bfloat16, +) + +config = TrinityInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config +) + +model = NeuronTrinityForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained( + model_path, padding_side="right", trust_remote_code=True +) +``` + +**Instance:** trn2.48xlarge only (TP=64, capacity block required, NVMe instance store for model storage). + +## Caveats + +1. **`padding_side="right"` required** -- NKI flash attention kernel does not support left-padding. Always set `padding_side="right"` on the tokenizer. + +2. 
**MoE v2 bf16 accumulation** -- The NxDI MoE v2 NKI kernel accumulates in bf16, causing ~23x more divergence per MoE layer compared to dense layers. Full-vocab cosine similarity is ~0.936, but top-1 token accuracy is preserved. A fix ticket has been filed. + +3. **`trust_remote_code=True` required** -- Trinity uses a custom `AfmoeForCausalLM` architecture not in standard transformers. The HuggingFace download requires `trust_remote_code=True`. + +4. **transformers version sensitivity** -- Use transformers 4.56.2 with SDK 2.27. Reference outputs may vary across transformers versions. + +5. **GLU type** -- Trinity uses `SiLU(gate) * up` which maps to NxDI's `glu_type="glu"`, NOT `"swiglu"`. This is handled automatically by the config class. + +6. **route_scale baked into weights** -- NxDI MoE v2 does not support `route_scale` natively. The scale is baked into routed expert `down_proj` weights during weight conversion. Shared expert weights are NOT scaled. + +7. **Gate padding at high TP** -- When `num_attention_heads` is not evenly divisible by `tp_degree` (e.g., Large at TP=64: 48/64), gate weights are padded with interleaved layout matching the Q projection. This is handled automatically during weight conversion. 
+ +## Compatibility Matrix + +| Model | Instance | TP | LNC | Status | +|-------|----------|-----|-----|--------| +| Nano | inf2.xlarge | 1 | N/A | FAIL (16GB system RAM OOM) | +| Nano | inf2.8xlarge | 1 | N/A | Validated | +| Nano | trn2.3xlarge | 2 | 2 | Validated | +| Mini | inf2.8xlarge | -- | -- | Does NOT fit | +| Mini | trn2.3xlarge | 4 | 2 | Validated | +| Large | trn2.48xlarge | 32 | 2 | FAIL (HBM OOM per NC) | +| Large | trn2.48xlarge | 64 | 2 | Validated | + +### Minimum Requirements by Model Size + +| Model | Min HBM | Min TP | Min Instance | +|-------|---------|--------|-------------| +| Nano | ~12GB bf16 | 1 | inf2.8xlarge (123GB system RAM required) | +| Mini | ~48GB bf16 | 4 | trn2.3xlarge | +| Large | ~500GB bf16 | 64 | trn2.48xlarge (capacity block, NVMe storage) | + +### SDK Configuration + +| Component | Version | +|-----------|---------| +| NxDI | 0.7.15063 | +| neuronx-cc | 2.22.12471 | +| torch-neuronx | 2.9.0.2.11 | +| torch | 2.9.0 | +| transformers | 4.56.2 | +| Venv | `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/` | + +## Testing + +```bash +# Set paths for your model +export TRINITY_MODEL_PATH="/path/to/model" +export TRINITY_COMPILED_PATH="/path/to/compiled" + +# Run integration tests +pytest test/integration/test_model.py --capture=tee-sys + +# Or run directly +python test/integration/test_model.py +``` + +**Prerequisites:** +- Pre-compiled model at `TRINITY_COMPILED_PATH` +- HuggingFace model weights downloaded with `trust_remote_code=True` +- Appropriate instance for model size (see Compatibility Matrix) + +## Key Porting Challenges + +This model required solving several non-trivial porting challenges: + +1. **GLU type mismatch:** Trinity uses `SiLU(gate)*up` which maps to NxDI's `"glu"` type, NOT `"swiglu"` (`gate*SiLU(gate)*up`). +2. **Gated attention:** Trinity applies `sigmoid(gate(input))` to attention output before o_proj. 
Solved via inline override of attention forward methods (required for Neuron tracer compatibility). +3. **Dual intermediate sizes:** Dense layers use `intermediate_size`, MoE experts use `moe_intermediate_size`. Config swaps values for MoE module compatibility. +4. **route_scale not supported by NxDI MoE v2:** Baked into expert `down_proj` weights during conversion. +5. **expert_bias not supported by NxDI:** Created custom `RouterTopKWithBias` subclass. +6. **Conditional RoPE:** Only sliding attention layers get rotary embeddings. +7. **Mixed attention masks:** Framework provides both global and local masks; decoder layer selects based on layer type. +8. **Gate weight padding at high TP:** Interleaved padding matching Q projection layout (prevents wrong-head gating on 54/64 cores). +9. **Shared expert weight loading:** Standalone module for reliable weight mapping vs NxDI built-in shared expert handling. + +## NKI Kernels + +The NxDI framework uses several NKI (Neuron Kernel Interface) kernels during Trinity compilation and inference. These are hardware-accelerated kernels that execute directly on Neuron cores. + +| Kernel | Source | Purpose | +|--------|--------|---------| +| **Flash Attention (Context Encoding)** | `neuronxcc.nki._pre_prod_kernels.attn_fwd` | Full-sequence attention during context encoding (prompt processing). Fused QKV attention with causal masking and sliding window support. | +| **Flash Attention ISA** | `neuronxcc.nki.kernels.attention.attention_isa_kernel` | ISA-level flash attention implementation used as BIR (Built-in Runtime) fallback for context encoding. | +| **Token Gen Attention** | `neuronxcc.nki._private_kernels.attention.attention_tkg_fwd_isa_kernel` | Single-token attention with KV cache lookup during autoregressive token generation. 
| +| **Token Gen Attention Block (Fused)** | `neuronxcc.nki._pre_prod_kernels.attention_token_gen.llama3_nki_attention_block_token_gen_kernel` | Fused kernel combining attention + RMSNorm + residual connection for token generation. Used when `attn_block_tkg_nki_kernel_enabled` is true. | +| **Blockwise Matmul (MoE Experts)** | `neuronx_distributed.modules.moe.blockwise.BlockwiseMatmulNKIFunc` | Expert MLP computation in MoE layers (gate, up, down projections). Handles sparse expert dispatch with token routing. **Note:** Accumulates in bf16, causing slightly higher numerical divergence vs CPU reference. | +| **Custom RMSNorm** | `neuronx_distributed_inference.modules.custom_calls.CustomRMSNorm` | Hardware-accelerated RMSNorm via `AwsNeuronRmsNorm` custom call. Used 4 times per decoder layer (input_norm, post_attn_norm, pre_ff_norm, post_ff_norm). | +| **Cumsum** | `neuronxcc.nki.kernels.cumsum` | Attention mask computation for causal mask prefix sums. Used in both context encoding and token generation paths. | +| **Router TopK** | `neuronx_distributed.kernels.router_topk_kernel` | Expert selection in MoE routing -- selects top-k experts from sigmoid routing scores. Used once per MoE layer. | + +### NKI Kernel Interaction with Trinity-Specific Features + +- **Gated attention bypass:** When NKI fused attention block kernels are enabled (`attn_block_tkg_nki_kernel_enabled` or `attn_block_cte_nki_kernel_enabled`), Trinity's custom gated attention is bypassed and the base class fused kernel is used instead. The gated attention path is used when fused kernels are disabled. +- **MoE bf16 accumulation:** The blockwise matmul NKI kernel accumulates expert outputs in bf16 rather than fp32, which is the primary source of numerical divergence between Neuron and CPU reference outputs. Top-1 token accuracy is preserved. +- **Left-padding unsupported:** The NKI flash attention kernels require right-padding (`padding_side="right"`). Left-padding produces incorrect results. 
+ +## Example Checkpoints + +- `arcee-ai/Trinity-Nano-Preview` (requires `trust_remote_code=True`) +- `arcee-ai/Trinity-Mini` (requires `trust_remote_code=True`) +- `arcee-ai/Trinity-Large-Preview` (requires `trust_remote_code=True`) + +## Maintainer + +Jim Burtoft + +**Last Updated:** 2026-02-27 diff --git a/contrib/models/Trinity/src/__init__.py b/contrib/models/Trinity/src/__init__.py new file mode 100644 index 00000000..309e2f52 --- /dev/null +++ b/contrib/models/Trinity/src/__init__.py @@ -0,0 +1,9 @@ +from .modeling_trinity import ( + TrinityInferenceConfig, + NeuronTrinityModel, + NeuronTrinityForCausalLM, + NeuronTrinityAttention, + NeuronTrinityMLP, + NeuronTrinitySharedExpert, + NeuronTrinityDecoderLayer, +) diff --git a/contrib/models/Trinity/src/modeling_trinity.py b/contrib/models/Trinity/src/modeling_trinity.py new file mode 100644 index 00000000..1706da94 --- /dev/null +++ b/contrib/models/Trinity/src/modeling_trinity.py @@ -0,0 +1,1328 @@ +#!/usr/bin/env python3 +""" +Unified NeuronX Distributed Inference implementation for the Trinity model family +(AfmoeForCausalLM) from Arcee AI. 
+ +Supports all three Trinity sizes from a single codebase: +- Trinity-Nano-Preview (~6B total, ~1B active) +- Trinity-Mini (~26B total, ~4.5B active) +- Trinity-Large-Preview (~250B total, ~15B active) + +Architecture (shared across all sizes): +- AfmoeForCausalLM: Arcee Foundation Mixture of Experts +- Mixed attention: sliding_attention + full_attention (every 4th layer) +- Gated attention: gate_proj + sigmoid on attention output +- QK normalization: RMSNorm on Q and K per head +- Dual layer norms: pre/post for both attention and MLP (4 per layer) +- muP scaling: hidden_size**0.5 on input embeddings +- Sigmoid routing with normalization +- SiLU gated MLP (gate_proj, up_proj, down_proj) +- Expert bias on routing scores + +Key porting decisions: +- glu_type="glu" (NOT "swiglu") -- Trinity uses SiLU(gate)*up, which is NxDI's "glu" +- route_scale baked into routed expert down_proj weights (NxDI MoE v2 doesn't support it) +- muP scaling baked into embedding weights during conversion +- expert_bias handled via custom RouterTopKWithBias subclass +- Gated attention handled via inline override of attention forward methods +- Gate weight padding uses interleaved layout matching Q projection (for high TP) +""" + +import json +import os +import math +from typing import List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, + MoENeuronConfig, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + 
RowParallelLinear, +) +from neuronx_distributed.utils import cpu_mode + +# MoE v2 module (required for MoE layers) +try: + from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + from neuronx_distributed.modules.moe.routing import RouterTopK + + MOE_V2_AVAILABLE = True +except ImportError: + MOE_V2_AVAILABLE = False + print("WARNING: moe_v2 not available, MoE layers will not work") + + +class RouterTopKWithBias(RouterTopK): + """RouterTopK with expert_bias support for Trinity. + + Trinity uses expert_bias to influence which experts are selected: + - Sigmoid scores are computed: scores = sigmoid(logits) + - For top-k selection: topk(scores + expert_bias) + - For actual routing weights: gather scores at selected indices (no bias) + + The bias only affects WHICH experts are selected, not their weights. + """ + + def __init__(self, expert_bias_size, **kwargs): + super().__init__(**kwargs) + self.register_buffer( + "expert_bias", + torch.zeros(expert_bias_size, dtype=torch.float32), + ) + + def forward(self, hidden_states): + router_logits = self.get_router_logits(hidden_states) + expert_affinities = self.apply_activation_fn(router_logits) + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + + # Top-k selection with expert_bias added to scores. 
+ scores_for_selection = expert_affinities.float() + self.expert_bias.float() + _, expert_index = torch.topk(scores_for_selection, self.top_k) + expert_index = expert_index.detach().to(dtype=torch.long) + + return router_logits, expert_affinities, expert_index + + +def initialize_moe_with_expert_bias(config): + """Initialize MoE module with expert_bias support.""" + moe = initialize_moe_module(config=config) + + old_router = moe.router + new_router = RouterTopKWithBias( + expert_bias_size=config.num_local_experts, + num_experts=old_router.num_experts, + top_k=old_router.top_k, + hidden_size=old_router.hidden_size, + dtype=old_router.dtype, + device=old_router.device, + act_fn=old_router.act_fn, + sequence_parallel_enabled=old_router.sequence_parallel_enabled, + sequence_dimension=old_router.sequence_dimension, + bias=old_router.bias, + apply_act_fn_over_topk=old_router.apply_act_fn_over_topk, + store_transposed_weights=old_router.store_transposed_weights, + ) + new_router.linear_router = old_router.linear_router + if hasattr(old_router, "weight_T"): + new_router.weight_T = old_router.weight_T + + moe.router = new_router + moe.eval() + return moe + + +def get_rmsnorm_cls(): + """Get the appropriate RMSNorm class based on execution mode.""" + if cpu_mode(): + + class StandardRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + return (self.weight * hidden_states).to(input_dtype) + + return StandardRMSNorm + else: + return CustomRMSNorm + + +class TrinityInferenceConfig(InferenceConfig): + """Configuration for Trinity (AfmoeForCausalLM) inference. 
+ + Handles all Trinity model sizes (Nano, Mini, Large) via config-driven values. + + IMPORTANT: initialize_moe_module reads config.intermediate_size for expert MLP + dimensions. Trinity has two different intermediate sizes: + - intermediate_size: used for dense MLP layers (first num_dense_layers) + - moe_intermediate_size: used for MoE expert MLPs + + We store the dense size as dense_intermediate_size and set intermediate_size to + moe_intermediate_size so that initialize_moe_module gets the correct value. + """ + + def __init__(self, neuron_config=None, **kwargs): + # Model architecture parameters from AfmoeConfig + self.vocab_size = kwargs.pop("vocab_size", 200192) + self.hidden_size = kwargs.pop("hidden_size", 2048) + + # CRITICAL: intermediate_size must be the MoE intermediate size for initialize_moe_module + dense_intermediate = kwargs.pop("intermediate_size", 6144) + moe_intermediate = kwargs.pop("moe_intermediate_size", 1024) + self.dense_intermediate_size = dense_intermediate + self.intermediate_size = moe_intermediate + self.moe_intermediate_size = moe_intermediate + + self.num_hidden_layers = kwargs.pop("num_hidden_layers", 32) + self.num_dense_layers = kwargs.pop("num_dense_layers", 2) + self.num_attention_heads = kwargs.pop("num_attention_heads", 32) + self.num_key_value_heads = kwargs.pop("num_key_value_heads", 4) + self.head_dim = kwargs.pop("head_dim", 128) + self.hidden_act = kwargs.pop("hidden_act", "silu") + self.max_position_embeddings = kwargs.pop("max_position_embeddings", 131072) + self.initializer_range = kwargs.pop("initializer_range", 0.02) + self.rms_norm_eps = kwargs.pop("rms_norm_eps", 1e-5) + self.use_cache = kwargs.pop("use_cache", True) + self.rope_theta = kwargs.pop("rope_theta", 10000.0) + self.rope_scaling = kwargs.pop("rope_scaling", None) + self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + self.attention_dropout = kwargs.pop("attention_dropout", 0.0) + + # MoE parameters + self.num_experts = 
kwargs.pop("num_experts", 128) + self.num_local_experts = kwargs.pop("num_local_experts", None) + if self.num_local_experts is None: + self.num_local_experts = self.num_experts + self.num_experts_per_tok = kwargs.pop("num_experts_per_tok", 8) + self.num_shared_experts = kwargs.pop("num_shared_experts", 1) + # IMPORTANT: Set n_shared_experts=0 for initialize_moe_module so the NxDI MoE + # module does NOT create its own SharedExperts. We handle shared experts ourselves + # in NeuronTrinityDecoderLayer to ensure proper weight loading. + self.n_shared_experts = 0 + self.num_expert_groups = kwargs.pop("num_expert_groups", 1) + self.num_limited_groups = kwargs.pop("num_limited_groups", 1) + self.score_func = kwargs.pop("score_func", "sigmoid") + self.route_norm = kwargs.pop("route_norm", True) + self.route_scale = kwargs.pop("route_scale", 1.0) + self.n_group = kwargs.pop("n_group", 1) + self.topk_group = kwargs.pop("topk_group", 1) + self.load_balance_coeff = kwargs.pop("load_balance_coeff", 0.001) + + # Attention patterns + self.global_attn_every_n_layers = kwargs.pop("global_attn_every_n_layers", 4) + self.sliding_window = kwargs.pop("sliding_window", 2048) + self.layer_types = kwargs.pop("layer_types", None) + + # Clamp sliding_window to seq_len if seq_len < sliding_window. + # The KV cache is sized by seq_len (via n_positions), and sliding window + # attention creates masks of size sliding_window. These must match. 
+ if neuron_config is not None and hasattr(neuron_config, "seq_len"): + if neuron_config.seq_len < self.sliding_window: + print( + f"NOTE: Clamping sliding_window from {self.sliding_window} to " + f"{neuron_config.seq_len} to match seq_len" + ) + self.sliding_window = neuron_config.seq_len + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if bool((i + 1) % self.global_attn_every_n_layers) + else "full_attention" + for i in range(self.num_hidden_layers) + ] + + # muP + self.mup_enabled = kwargs.pop("mup_enabled", True) + + # Standard attributes + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self.torch_dtype = kwargs.pop("torch_dtype", "bfloat16") + self.attention_bias = kwargs.pop("attention_bias", False) + self.output_attentions = kwargs.pop("output_attentions", False) + self.output_hidden_states = kwargs.pop("output_hidden_states", False) + + # Pop HF-specific keys not used by our config + kwargs.pop("auto_map", None) + kwargs.pop("architectures", None) + kwargs.pop("model_type", None) + kwargs.pop("transformers_version", None) + kwargs.pop("dtype", None) + kwargs.pop("use_grouped_mm", None) + + super().__init__(neuron_config=neuron_config, **kwargs) + + # Adjust num_local_experts for expert parallelism + if hasattr(self, "neuron_config") and self.neuron_config is not None: + ep_degree = getattr(self.neuron_config, "ep_degree", 1) + if ep_degree > 1: + self.num_local_experts = self.num_experts // ep_degree + + # Set MoE neuron config parameters + if hasattr(self, "neuron_config") and self.neuron_config is not None: + if not hasattr(self.neuron_config, "glu_mlp"): + self.neuron_config.glu_mlp = True + # Trinity uses SiLU(gate)*up which is NxDI's "glu" type, + # NOT "swiglu" which computes gate*SiLU(gate)*up + self.neuron_config.glu_type = "glu" + # Trinity uses sigmoid routing (not softmax) + if 
hasattr(self.neuron_config, "router_config"): + self.neuron_config.router_config.act_fn = "sigmoid" + + def add_derived_config(self): + """Add derived configuration parameters.""" + self.num_cores_per_group = 1 + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "vocab_size", + "max_position_embeddings", + "num_local_experts", + "num_experts_per_tok", + "intermediate_size", + "head_dim", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return MoENeuronConfig + + @classmethod + def from_pretrained(cls, model_path: str, **kwargs) -> "TrinityInferenceConfig": + neuron_config = kwargs.pop("neuron_config", None) + model_path = os.path.expanduser(model_path) + config_path = os.path.join(model_path, "config.json") + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found at {config_path}") + with open(config_path, "r") as f: + config_dict = json.load(f) + config_dict.update(kwargs) + config = cls(neuron_config=neuron_config, **config_dict) + return config + + +class NeuronTrinityAttention(NeuronAttentionBase): + """Trinity attention with QK norms, conditional RoPE, and gated output. + + Key differences from standard attention: + 1. QK norms: RMSNorm applied to Q and K per head before attention + 2. Conditional RoPE: Only applied for sliding_attention layers, not full_attention + 3. Gated output: output = o_proj(attn_out * sigmoid(gate_proj(input))) + + Gating strategy (inline override): + The Neuron tracer cannot follow tensor flow through mutable state, closures, + or dynamic method replacement. The ONLY working approach is to have the gate + computation INLINE in the same method that calls o_proj. 
+ + We override standard_causal_attention_forward and windowed_attention_forward + to insert gate_values = sigmoid(attn_gate_proj(original_hidden_states)) + and apply attn_output = attn_output * gate_values before the o_proj call. + """ + + def __init__(self, config: TrinityInferenceConfig, layer_idx: int): + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + + # RoPE only for sliding attention layers + if is_sliding: + rotary_emb = RotaryEmbedding( + config.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + else: + rotary_emb = None + + sliding_window = config.sliding_window if is_sliding else None + + # Per-head QK norm + rmsnorm_cls = get_rmsnorm_cls() + q_norm = rmsnorm_cls(config.head_dim, eps=config.rms_norm_eps) + k_norm = rmsnorm_cls(config.head_dim, eps=config.rms_norm_eps) + + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rope_theta=config.rope_theta if is_sliding else None, + rms_norm_eps=config.rms_norm_eps, + use_qk_norm=False, + q_layernorm=q_norm, + k_layernorm=k_norm, + sliding_window=sliding_window, + ) + + self.layer_idx = layer_idx + self.is_sliding = is_sliding + + # Gated attention: gate_proj applied before o_proj. + # Must match the actual per-rank attention output size from NxDI. + # When num_attention_heads is not divisible by TP, NxDI pads to + # ceil(num_heads/tp) heads per rank. We must match that padding. 
+ tp_degree = config.neuron_config.tp_degree + heads_per_rank = math.ceil(config.num_attention_heads / tp_degree) + padded_total_heads = heads_per_rank * tp_degree + gate_output_size = padded_total_heads * config.head_dim + + self.attn_gate_proj = ColumnParallelLinear( + config.hidden_size, + gate_output_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + + def _apply_gated_o_proj(self, attn_output, gate_hidden_states, adapter_ids=None): + """Apply gating then o_proj, all inline for Neuron tracing. + + This method MUST be called from within the same forward pass where + gate_hidden_states is a live tensor in the traced graph. + """ + gate_values = torch.sigmoid(self.attn_gate_proj(gate_hidden_states)) + attn_output = attn_output * gate_values + return self.get_o_proj()(attn_output, adapter_ids=adapter_ids) + + def standard_causal_attention_forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + active_mask=None, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + rotary_position_ids=None, + kv_mgr=None, + get_kv_per_layer=False, + update_kv_per_layer=False, + residual=None, + windowed_context_encoding_window_idx=-1, + **kwargs, + ): + """Override base class to insert gating before o_proj. + + Copied from NeuronAttentionBase.standard_causal_attention_forward with + one change: the o_proj call is replaced with _apply_gated_o_proj. 
+ """ + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBaseOutput, + ) + + use_polar_compatible_rope = kwargs.get("use_polar_compatible_rope", False) + + # Save original hidden_states for gate computation BEFORE dtype conversion + gate_hidden_states = hidden_states + + original_dtype = hidden_states.dtype + hidden_states = hidden_states.to(self.torch_dtype) + seq_ids = kwargs.get("seq_ids") + is_context_parallel = past_key_value is None and self.cp_degree > 1 + is_data_parallel = past_key_value is not None and self.dp_degree > 1 + if is_context_parallel: + attention_mask, hidden_states, position_ids, cos_cache, sin_cache = ( + self._split_inputs_for_context_parallel( + attention_mask, hidden_states, position_ids, cos_cache, sin_cache + ) + ) + + if is_data_parallel: + from neuronx_distributed_inference.modules.attention.attention_base import ( + get_dp_rank, + split_along_dim, + get_data_parallel_attention_dp_group, + gather_from_tensor_model_parallel_region_with_dim, + ) + + dp_rank = get_dp_rank( + self.rank_util.get_rank(), + self.tp_degree, + self.dp_degree, + self.neuron_config.switch_cc, + ) + hidden_states = split_along_dim( + hidden_states, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + attention_mask = split_along_dim( + attention_mask, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + position_ids = split_along_dim( + position_ids, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + + bsz, q_len, _ = hidden_states.size() + if self.sequence_parallel_enabled: + q_len *= self.tensor_model_parallel_group.size() + + if rotary_position_ids is None: + rotary_position_ids = position_ids + + if get_kv_per_layer: + assert kv_mgr is not None + past_key_value = kv_mgr.get_kv_by_layer_id(**kwargs) + + is_token_gen = past_key_value is not None + + if windowed_context_encoding_window_idx >= 0: + is_token_gen = False + + if self.neuron_config.is_prefix_caching: + is_token_gen = is_token_gen and q_len < 
128 + + # NKI kernel paths -- delegate to base class (no custom gating in fused kernels) + if self.attn_block_tkg_nki_kernel_enabled and is_token_gen: + return super().standard_causal_attention_forward( + gate_hidden_states.to(self.torch_dtype) + if is_context_parallel or is_data_parallel + else gate_hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + adapter_ids, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + kv_mgr, + get_kv_per_layer, + update_kv_per_layer, + residual, + windowed_context_encoding_window_idx, + **kwargs, + ) + + if ( + self.attn_block_cte_nki_kernel_enabled + and not is_token_gen + and not self.neuron_config.is_prefix_caching + ): + return super().standard_causal_attention_forward( + gate_hidden_states.to(self.torch_dtype) + if is_context_parallel or is_data_parallel + else gate_hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + adapter_ids, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + kv_mgr, + get_kv_per_layer, + update_kv_per_layer, + residual, + windowed_context_encoding_window_idx, + **kwargs, + ) + + tkg_attn_kernel_fused_rope = ( + is_token_gen and self.attn_tkg_builtin_kernel_enabled + ) + + Q, K, V, cos_cache, sin_cache, residual = self.prep_qkv_tensors( + rotary_position_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + skip_rope=tkg_attn_kernel_fused_rope, + residual=residual, + use_polar_compatible_rope=use_polar_compatible_rope, + ) + + if is_token_gen: + if tkg_attn_kernel_fused_rope: + attn_output, K = self.attention_tokengen_kernel_builtin( + Q, + K, + V, + position_ids, + past_key_value, + attention_mask, + active_mask, + rotary_position_ids, + ) + else: + attn_output = self.attention_tokengen( + Q, + K, + V, + attention_mask, + position_ids, + past_key_value, + active_mask, + **kwargs, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + 
else: + attn_output, K, V = self.attention_context_encode( + Q, K, V, q_len, bsz, attention_mask, past_key_value, active_mask + ) + + # merge multi head hidden + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # *** GATED ATTENTION: apply gate BEFORE o_proj, all inline *** + attn_output = self._apply_gated_o_proj( + attn_output, gate_hidden_states, adapter_ids=adapter_ids + ) + + if self.k_cache_transposed: + K = K.permute(0, 1, 3, 2) + + kv = (K, V) + + if update_kv_per_layer: + assert kv_mgr is not None + kv = kv_mgr.update_kv_by_layer_id( + kv_per_layer=kv, + position_ids=position_ids, + **kwargs, + ) + + if is_context_parallel and not self.sequence_parallel_enabled: + from neuronx_distributed_inference.modules.attention.attention_base import ( + gather_from_tensor_model_parallel_region_with_dim, + get_context_parallel_attention_cp_group, + ) + + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=1, + process_group=get_context_parallel_attention_cp_group(), + ) + + if is_data_parallel: + from neuronx_distributed_inference.modules.attention.attention_base import ( + gather_from_tensor_model_parallel_region_with_dim, + get_data_parallel_attention_dp_group, + ) + + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=0, + process_group=get_data_parallel_attention_dp_group(), + ) + + attn_output = attn_output.to(original_dtype) + + return NeuronAttentionBaseOutput( + attn_output, kv, cos_cache, sin_cache, residual + ) + + def windowed_attention_forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + active_mask=None, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + rotary_position_ids=None, + kv_mgr=None, + get_kv_per_layer=False, + update_kv_per_layer=False, + residual=None, + windowed_context_encoding_window_idx=-1, + **kwargs, + ): + """Override base class to insert gating 
before o_proj. + + Copied from NeuronAttentionBase.windowed_attention_forward with + one change: the o_proj call is replaced with _apply_gated_o_proj. + """ + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBaseOutput, + get_last_kv_window, + ) + + # Save original hidden_states for gate computation BEFORE any modifications + gate_hidden_states = hidden_states + + is_context_parallel = past_key_value is None and self.cp_degree > 1 + is_data_parallel = past_key_value is not None and self.dp_degree > 1 + + full_position_ids = position_ids.clone() + + if is_context_parallel: + attention_mask, hidden_states, position_ids, cos_cache, sin_cache = ( + self._split_inputs_for_context_parallel( + attention_mask, hidden_states, position_ids, cos_cache, sin_cache + ) + ) + + if is_data_parallel: + from neuronx_distributed_inference.modules.attention.attention_base import ( + get_dp_rank, + split_along_dim, + ) + + dp_rank = get_dp_rank( + self.rank_util.get_rank(), + self.tp_degree, + self.dp_degree, + self.neuron_config.switch_cc, + ) + hidden_states = split_along_dim( + hidden_states, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + attention_mask = split_along_dim( + attention_mask, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + position_ids = split_along_dim( + position_ids, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + + bsz, q_len, _ = hidden_states.size() + if self.sequence_parallel_enabled: + q_len *= self.tensor_model_parallel_group.size() + + if rotary_position_ids is None: + rotary_position_ids = position_ids + + if get_kv_per_layer: + assert kv_mgr is not None + past_key_value = kv_mgr.get_kv_by_layer_id(**kwargs) + + is_token_gen = past_key_value is not None + + if windowed_context_encoding_window_idx >= 0: + is_token_gen = False + + # NKI kernel path -- delegate to base class (no gating) + if self.attn_block_tkg_nki_kernel_enabled and is_token_gen: + return 
super().windowed_attention_forward( + gate_hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + adapter_ids, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + kv_mgr, + get_kv_per_layer, + update_kv_per_layer, + residual, + windowed_context_encoding_window_idx, + **kwargs, + ) + + tkg_attn_kernel_fused_rope = ( + is_token_gen and self.attn_tkg_builtin_kernel_enabled + ) + + Q, K, V, cos_cache, sin_cache, residual = self.prep_qkv_tensors( + rotary_position_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + skip_rope=tkg_attn_kernel_fused_rope, + residual=residual, + ) + + if is_token_gen: + attn_output = self.attention_tokengen( + Q, + K, + V, + attention_mask, + position_ids, + past_key_value, + active_mask, + **kwargs, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + else: + attn_output, K, V = self.attention_context_encode_windowed_attention( + Q, + K, + V, + q_len, + bsz, + attention_mask, + self.sliding_window, + past_key_value, + active_mask, + ) + K, V = get_last_kv_window( + self.sliding_window, + full_position_ids, + K, + V, + windowed_context_encoding_window_idx, + self.neuron_config.speculation_length, + ) + + # merge multi head hidden + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # *** GATED ATTENTION: apply gate BEFORE o_proj, all inline *** + attn_output = self._apply_gated_o_proj( + attn_output, gate_hidden_states, adapter_ids=adapter_ids + ) + + if self.k_cache_transposed: + K = K.permute(0, 1, 3, 2) + + kv = (K, V) + + if update_kv_per_layer: + assert kv_mgr is not None + kv = kv_mgr.update_kv_by_layer_id( + kv_per_layer=kv, + position_ids=position_ids, + **kwargs, + ) + + if is_context_parallel and not self.sequence_parallel_enabled: + from neuronx_distributed_inference.modules.attention.attention_base import ( + gather_from_tensor_model_parallel_region_with_dim, + 
get_context_parallel_attention_cp_group, + ) + + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=1, + process_group=get_context_parallel_attention_cp_group(), + ) + + return NeuronAttentionBaseOutput( + attn_output, kv, cos_cache, sin_cache, residual + ) + + +class NeuronTrinityMLP(nn.Module): + """Dense MLP for non-MoE layers (first num_dense_layers layers). + + Uses dense_intermediate_size, NOT the MoE intermediate_size. + """ + + def __init__(self, config: TrinityInferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + intermediate = config.dense_intermediate_size + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + intermediate, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + intermediate, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.down_proj = RowParallelLinear( + intermediate, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + ) + self.act_fn = F.silu + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronTrinitySharedExpert(nn.Module): + """Shared expert MLP for MoE layers. + + Trinity has num_shared_experts=1. Each MoE layer has a shared expert whose + output is added to the routed expert output for every token. Uses the same + SiLU-gated MLP architecture as the dense layers but with moe_intermediate_size. + + Implemented as a standalone module (separate from NxDI's MoE SharedExperts) + to ensure reliable weight loading via standard ColumnParallelLinear/RowParallelLinear. 
+ """ + + def __init__(self, config: TrinityInferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + intermediate = config.moe_intermediate_size + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + intermediate, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + intermediate, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.down_proj = RowParallelLinear( + intermediate, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + ) + self.act_fn = F.silu + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronTrinityDecoderLayer(nn.Module): + """Trinity decoder layer with dual layer norms and conditional MoE. + + Structure: + - input_layernorm -> attention -> post_attention_layernorm -> residual + - pre_mlp_layernorm -> MLP/MoE -> post_mlp_layernorm -> residual + """ + + def __init__(self, config: TrinityInferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = NeuronTrinityAttention(config, layer_idx) + self.attention_type = config.layer_types[layer_idx] + + rmsnorm_cls = get_rmsnorm_cls() + self.input_layernorm = rmsnorm_cls(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = rmsnorm_cls( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_mlp_layernorm = rmsnorm_cls( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_mlp_layernorm = rmsnorm_cls( + config.hidden_size, eps=config.rms_norm_eps + ) + + # MoE for layers >= num_dense_layers, dense MLP otherwise + self.moe_enabled = layer_idx >= config.num_dense_layers + if self.moe_enabled and MOE_V2_AVAILABLE: + self.mlp = initialize_moe_with_expert_bias(config=config) + # Shared expert: handled outside NxDI 
MoE to ensure reliable weight loading + if config.num_shared_experts > 0: + self.shared_expert = NeuronTrinitySharedExpert(config) + else: + self.shared_expert = None + else: + self.mlp = NeuronTrinityMLP(config) + self.shared_expert = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + residual = hidden_states + normed = self.input_layernorm(hidden_states) + + # Select correct attention mask for this layer type. + # Mixed attention: local_mask for sliding layers, attention_mask for full layers. + local_mask = kwargs.pop("local_mask", None) + mask = local_mask + if self.attention_type == "full_attention" or local_mask is None: + mask = attention_mask + + attn_output, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=normed, + attention_mask=mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + attn_output = self.post_attention_layernorm(attn_output) + hidden_states = residual + attn_output + + # MLP with dual norms + residual = hidden_states + hidden_states = self.pre_mlp_layernorm(hidden_states) + + if self.moe_enabled and MOE_V2_AVAILABLE: + mlp_output = self.mlp(hidden_states, padding_mask)[0] + # Add shared expert output (applied to every token) + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + mlp_output = mlp_output + shared_output + hidden_states = mlp_output + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, 
cos_cache, sin_cache, None) + return outputs + + +class NeuronTrinityModel(NeuronBaseModel): + """NeuronX Trinity base model (all sizes).""" + + def setup_attr_for_model(self, config: TrinityInferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = getattr(config.neuron_config, "buckets", None) + + # Mixed attention: set sliding_window and has_mixed_attn so the framework + # creates both global and local (windowed) masks + self.sliding_window = getattr(config, "sliding_window", None) + self.has_mixed_attn = True + + def init_model(self, config: TrinityInferenceConfig): + self.padding_idx = getattr(config, "pad_token_id", None) + self.vocab_size = config.vocab_size + self.mup_enabled = getattr(config, "mup_enabled", False) + self.mup_scale = math.sqrt(config.hidden_size) if self.mup_enabled else 1.0 + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + + self.layers = nn.ModuleList( + [ + NeuronTrinityDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + rmsnorm_cls = get_rmsnorm_cls() + self.norm = rmsnorm_cls(config.hidden_size, eps=config.rms_norm_eps) + + # Pad vocab_size to be divisible by TP degree for ColumnParallelLinear + tp_degree = config.neuron_config.tp_degree + padded_vocab = config.vocab_size + if padded_vocab % tp_degree != 0: + padded_vocab = ((padded_vocab // tp_degree) + 1) * tp_degree + self.padded_vocab_size = padded_vocab + self.actual_vocab_size = config.vocab_size + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + padded_vocab, + 
gather_output=False if self.on_device_sampling else True, + bias=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + ) + + +class NeuronTrinityForCausalLM(NeuronBaseForCausalLM): + """NeuronX wrapper for Trinity causal language models (all sizes). + + Supports: + - arcee-ai/Trinity-Nano-Preview (~6B total, ~1B active) + - arcee-ai/Trinity-Mini (~26B total, ~4.5B active) + - arcee-ai/Trinity-Large-Preview (~250B total, ~15B active) + """ + + _model_cls = NeuronTrinityModel + + @classmethod + def get_config_cls(cls): + return TrinityInferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: + """Convert HuggingFace AfmoeForCausalLM state dict to NeuronX format. + + Key transformations: + 1. Remove 'model.' prefix from HF keys + 2. Rename QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + 3. Map attention gate_proj to attn_gate_proj (gated attention) + 4. Stack per-expert weights into [E, H, 2*I] gate_up_proj format + 5. Map router: router.gate.weight -> router.linear_router.weight + 6. Map shared expert weights to standalone shared_expert module + 7. Bake muP scaling into embedding weights + 8. Bake route_scale into routed expert down_proj weights + 9. Pad gate_proj weights with interleaved layout (when num_heads % TP != 0) + 10. 
Pad lm_head weights (when vocab_size % TP != 0) + """ + neuron_state_dict = {} + neuron_config = config.neuron_config + target_dtype = torch.bfloat16 + + has_model_prefix = any(k.startswith("model.") for k in state_dict.keys()) + + def strip_prefix(key): + if has_model_prefix and key.startswith("model."): + return key[6:] + return key + + # Direct mappings: embeddings, final norm, lm_head + for key, value in state_dict.items(): + stripped = strip_prefix(key) + + if stripped == "embed_tokens.weight": + embed_weight = value.to(target_dtype) + mup_enabled = getattr(config, "mup_enabled", False) + if mup_enabled: + mup_scale = math.sqrt(config.hidden_size) + embed_weight = embed_weight * mup_scale + neuron_state_dict["embed_tokens.weight"] = embed_weight + continue + if stripped == "norm.weight": + neuron_state_dict["norm.weight"] = value.to(target_dtype) + continue + if key == "lm_head.weight": + lm_weight = value.to(target_dtype) + # Pad lm_head to be divisible by TP degree + tp_degree = neuron_config.tp_degree + vocab_size = lm_weight.shape[0] + if vocab_size % tp_degree != 0: + padded_vocab = ((vocab_size // tp_degree) + 1) * tp_degree + pad_rows = padded_vocab - vocab_size + lm_weight = torch.cat( + [ + lm_weight, + torch.zeros( + pad_rows, lm_weight.shape[1], dtype=target_dtype + ), + ], + dim=0, + ) + neuron_state_dict["lm_head.weight"] = lm_weight + continue + + # Layer-by-layer conversion + num_layers = config.num_hidden_layers + num_experts = config.num_local_experts + moe_intermediate = config.moe_intermediate_size + hidden_size = config.hidden_size + num_dense_layers = getattr(config, "num_dense_layers", 2) + + for layer_idx in range(num_layers): + if has_model_prefix: + hf_prefix = f"model.layers.{layer_idx}" + else: + hf_prefix = f"layers.{layer_idx}" + neuron_prefix = f"layers.{layer_idx}" + + # Layer norms (4 per layer) + for norm_name in [ + "input_layernorm", + "post_attention_layernorm", + "pre_mlp_layernorm", + "post_mlp_layernorm", + ]: + hf_key = 
f"{hf_prefix}.{norm_name}.weight" + if hf_key in state_dict: + neuron_state_dict[f"{neuron_prefix}.{norm_name}.weight"] = ( + state_dict[hf_key].to(target_dtype) + ) + + # Attention Q, K, V projections + for proj in ["q_proj", "k_proj", "v_proj"]: + hf_key = f"{hf_prefix}.self_attn.{proj}.weight" + if hf_key in state_dict: + neuron_state_dict[ + f"{neuron_prefix}.self_attn.qkv_proj.{proj}.weight" + ] = state_dict[hf_key].to(target_dtype) + + # O projection + hf_key = f"{hf_prefix}.self_attn.o_proj.weight" + if hf_key in state_dict: + neuron_state_dict[f"{neuron_prefix}.self_attn.o_proj.weight"] = ( + state_dict[hf_key].to(target_dtype) + ) + + # QK norm weights: q_norm -> q_layernorm, k_norm -> k_layernorm + for hf_norm, neuron_norm in [ + ("q_norm", "q_layernorm"), + ("k_norm", "k_layernorm"), + ]: + hf_key = f"{hf_prefix}.self_attn.{hf_norm}.weight" + if hf_key in state_dict: + neuron_state_dict[ + f"{neuron_prefix}.self_attn.{neuron_norm}.weight" + ] = state_dict[hf_key].to(target_dtype) + + # Attention gate_proj (gated attention, Trinity-specific) + # CRITICAL: Must use INTERLEAVED padding matching Q projection layout. + # NxDI pads Q with maybe_pad_interleaved (REPLICATE_TO_TP_DEGREE), + # inserting zero heads between KV groups. The gate_proj output is + # element-wise multiplied with the attention output (which follows + # the Q head layout), so gate_proj MUST use the same interleaved + # padding pattern. Using tail padding causes cores to apply gate + # weights from the wrong head. 
+ hf_key = f"{hf_prefix}.self_attn.gate_proj.weight" + if hf_key in state_dict: + gate_weight = state_dict[hf_key].to(target_dtype) + tp_degree = neuron_config.tp_degree + num_heads = config.num_attention_heads + num_kv_heads = config.num_key_value_heads + head_dim = config.head_dim + if num_heads % tp_degree != 0: + # Use interleaved padding matching Q layout + padded_total_heads = math.ceil(num_heads / tp_degree) * tp_degree + group_size = num_heads // num_kv_heads # Q heads per KV group + # Reshape gate to (num_heads, head_dim, hidden_size) + gate_3d = gate_weight.view(num_heads, head_dim, -1) + # Split into KV groups + groups = gate_3d.split(group_size, dim=0) + pad_per_group = (padded_total_heads - num_heads) // num_kv_heads + # Interleave with zero padding after each group + interleaved = [] + for group in groups: + interleaved.append(group) + interleaved.append( + torch.zeros( + pad_per_group, + head_dim, + gate_weight.shape[1], + dtype=target_dtype, + ) + ) + gate_weight = torch.cat(interleaved, dim=0).view( + padded_total_heads * head_dim, -1 + ) + neuron_state_dict[ + f"{neuron_prefix}.self_attn.attn_gate_proj.weight" + ] = gate_weight + + # MLP weights + if layer_idx < num_dense_layers: + # Dense layers (uses dense_intermediate_size) + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"{hf_prefix}.mlp.{proj_name}.weight" + if hf_key in state_dict: + neuron_state_dict[f"{neuron_prefix}.mlp.{proj_name}.weight"] = ( + state_dict[hf_key].to(target_dtype) + ) + else: + # MoE layers + # Router: router.gate.weight -> router.linear_router.weight + hf_router_key = f"{hf_prefix}.mlp.router.gate.weight" + if hf_router_key in state_dict: + neuron_state_dict[ + f"{neuron_prefix}.mlp.router.linear_router.weight" + ] = state_dict[hf_router_key].to(target_dtype) + + # Expert bias (Trinity-specific routing parameter) + hf_bias_key = f"{hf_prefix}.mlp.expert_bias" + if hf_bias_key in state_dict: + 
neuron_state_dict[f"{neuron_prefix}.mlp.router.expert_bias"] = ( + state_dict[hf_bias_key].to(torch.float32) + ) + + # Stack expert weights for NxDI MoE v2 format + gate_up_proj = torch.empty( + num_experts, hidden_size, 2 * moe_intermediate, dtype=target_dtype + ) + down_proj = torch.empty( + num_experts, moe_intermediate, hidden_size, dtype=target_dtype + ) + + all_experts_found = True + for e in range(num_experts): + gate_key = f"{hf_prefix}.mlp.experts.{e}.gate_proj.weight" + up_key = f"{hf_prefix}.mlp.experts.{e}.up_proj.weight" + down_key = f"{hf_prefix}.mlp.experts.{e}.down_proj.weight" + + if ( + gate_key in state_dict + and up_key in state_dict + and down_key in state_dict + ): + gate_w = state_dict[gate_key].to(target_dtype) + up_w = state_dict[up_key].to(target_dtype) + down_w = state_dict[down_key].to(target_dtype) + + gate_up_concat = torch.cat([gate_w, up_w], dim=0) + gate_up_proj[e] = gate_up_concat.T + down_proj[e] = down_w.T + else: + all_experts_found = False + break + + if all_experts_found: + # Bake route_scale into routed expert down_proj weights. + # NxDI MoE v2 does NOT support route_scale natively. + # Shared experts are NOT scaled. 
+ route_scale = getattr(config, "route_scale", 1.0) + if route_scale != 1.0: + down_proj = down_proj * route_scale + + neuron_state_dict[ + f"{neuron_prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up_proj + neuron_state_dict[ + f"{neuron_prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = down_proj + + # Shared expert weights (mapped to standalone NeuronTrinitySharedExpert) + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"{hf_prefix}.mlp.shared_experts.{proj_name}.weight" + if hf_key in state_dict: + neuron_state_dict[ + f"{neuron_prefix}.shared_expert.{proj_name}.weight" + ] = state_dict[hf_key].to(target_dtype) + + # Rank utilities for tensor parallel + tp_degree = neuron_config.tp_degree + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + for i in range(num_layers): + neuron_state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + return neuron_state_dict + + def get_compiler_args(self): + """Get compiler arguments for Trinity models.""" + return "--model-type=transformer -O1" diff --git a/contrib/models/Trinity/test/__init__.py b/contrib/models/Trinity/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Trinity/test/integration/__init__.py b/contrib/models/Trinity/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Trinity/test/integration/test_model.py b/contrib/models/Trinity/test/integration/test_model.py new file mode 100644 index 00000000..e0feba1f --- /dev/null +++ b/contrib/models/Trinity/test/integration/test_model.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +""" +Integration tests for Trinity (AfmoeForCausalLM) NeuronX implementation. + +Supports all three Trinity model sizes (Nano, Mini, Large) via environment variables. 
+ +Usage: + # Set paths for your model size + export TRINITY_MODEL_PATH="/path/to/model" + export TRINITY_COMPILED_PATH="/path/to/compiled" + + # Run tests + pytest test/integration/test_trinity.py --capture=tee-sys + +Prerequisites: + - Pre-compiled model at TRINITY_COMPILED_PATH + - HuggingFace model weights at TRINITY_MODEL_PATH (downloaded with trust_remote_code=True) + - Appropriate instance for model size (see README.md) +""" + +import os +import pytest +import torch +import json +from pathlib import Path +from transformers import AutoTokenizer + +from neuronx_distributed_inference.models.config import MoENeuronConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + +# Import from src directory +import sys + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + + +# Configuration via environment variables +MODEL_PATH = os.environ.get( + "TRINITY_MODEL_PATH", + "/home/ubuntu/trinity/model/", +) +COMPILED_MODEL_PATH = os.environ.get( + "TRINITY_COMPILED_PATH", + "/home/ubuntu/trinity/compiled/", +) + + +def load_neuron_config_from_compiled(compiled_path: str): + """Load neuron configuration from compiled model's neuron_config.json.""" + config_path = Path(compiled_path) / "neuron_config.json" + + if not config_path.exists(): + raise FileNotFoundError(f"neuron_config.json not found: {config_path}") + + with open(config_path) as f: + config_data = json.load(f) + + if "neuron_config" in config_data: + return config_data["neuron_config"] + else: + return config_data + + +def create_model_for_inference(compiled_path: str, model_path: str): + """Create model for inference using compiled neuron_config.""" + neuron_config_dict = load_neuron_config_from_compiled(compiled_path) + + dtype_str = neuron_config_dict.get("torch_dtype", "torch.bfloat16") + if isinstance(dtype_str, str): + dtype = ( + getattr(torch, dtype_str.split(".")[1]) + if 
dtype_str.startswith("torch.") + else torch.bfloat16 + ) + else: + dtype = dtype_str + + neuron_config_kwargs = { + "tp_degree": neuron_config_dict.get("tp_degree", 4), + "batch_size": neuron_config_dict.get("batch_size", 1), + "seq_len": neuron_config_dict.get("seq_len", 2048), + "torch_dtype": dtype, + } + + neuron_config = MoENeuronConfig(**neuron_config_kwargs) + + try: + model_config = TrinityInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config + ) + except (TypeError, AttributeError): + model_config = TrinityInferenceConfig( + neuron_config, + load_config=load_pretrained_config(model_path), + ) + + model = NeuronTrinityForCausalLM(model_path, model_config) + return model, neuron_config + + +def generate_with_neuron_model(model, input_ids, max_new_tokens: int): + """Generate tokens using manual forward pass loop.""" + generated_ids = input_ids.clone() + + for _ in range(max_new_tokens): + seq_len = generated_ids.shape[1] + position_ids = ( + torch.arange(seq_len).unsqueeze(0).expand(generated_ids.shape[0], -1) + ) + + with torch.no_grad(): + outputs = model(generated_ids, position_ids=position_ids) + + if hasattr(outputs, "logits"): + logits = outputs.logits + elif isinstance(outputs, tuple): + logits = outputs[0] + else: + logits = outputs + + next_token_logits = logits[:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) + generated_ids = torch.cat([generated_ids, next_token], dim=-1) + + return generated_ids + + +@pytest.fixture(scope="module") +def compiled_model(): + """Load pre-compiled model.""" + model, neuron_config = create_model_for_inference(COMPILED_MODEL_PATH, MODEL_PATH) + model.load(COMPILED_MODEL_PATH) + return model + + +@pytest.fixture(scope="module") +def tokenizer(): + """Load tokenizer.""" + tokenizer = AutoTokenizer.from_pretrained( + MODEL_PATH, padding_side="right", trust_remote_code=True + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer 
+ + +def test_model_loads(compiled_model): + """Test that model loads successfully (smoke test).""" + assert compiled_model is not None + assert hasattr(compiled_model, "config") + print("Smoke test passed - Model loaded successfully") + + +def test_model_generates(compiled_model, tokenizer): + """Test that model can generate text.""" + prompt = "Hello, how are you?" + inputs = tokenizer(prompt, return_tensors="pt", padding=True) + + generated_ids = generate_with_neuron_model( + compiled_model, inputs.input_ids, max_new_tokens=20 + ) + output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + assert len(output_text) > len(prompt), "Output should be longer than prompt" + print(f"Generation test passed") + print(f" Output: {output_text}") + + +def test_output_coherence(compiled_model, tokenizer): + """Test that output is coherent (not gibberish or repetitive).""" + prompt = "Hello, how are you?" + inputs = tokenizer(prompt, return_tensors="pt", padding=True) + + generated_ids = generate_with_neuron_model( + compiled_model, inputs.input_ids, max_new_tokens=30 + ) + output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + assert len(output_text.split()) > 3, "Output should have multiple words" + assert not _is_repetitive(output_text), "Output should not be repetitive" + + print(f"Coherence test passed") + print(f" Output: {output_text[:100]}...") + + +def test_top_token_valid(compiled_model, tokenizer): + """Test that the top predicted token is a valid, decodable token. + + Unlike model-specific tests, this does not check for a specific expected token + since different Trinity sizes produce different outputs. + """ + prompt = "Hello, how are you?" 
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True) + + seq_len = inputs.input_ids.shape[1] + position_ids = ( + torch.arange(seq_len).unsqueeze(0).expand(inputs.input_ids.shape[0], -1) + ) + + with torch.no_grad(): + outputs = compiled_model(inputs.input_ids, position_ids=position_ids) + + if hasattr(outputs, "logits"): + logits = outputs.logits + elif isinstance(outputs, tuple): + logits = outputs[0] + else: + logits = outputs + + next_token_logits = logits[:, -1, :] + top_token_id = torch.argmax(next_token_logits, dim=-1).item() + top_token = tokenizer.decode([top_token_id]).strip() + + print(f"Top token: '{top_token}' (id={top_token_id})") + print(f"Top logit: {next_token_logits[0, top_token_id].item():.2f}") + + # The top token should be a non-empty, printable string + assert len(top_token) > 0, f"Top token should be non-empty, got '{top_token}'" + assert top_token_id < tokenizer.vocab_size, "Token ID should be within vocab range" + print("Top token validation passed") + + +def _is_repetitive(text: str, max_repeat: int = 5) -> bool: + """Check if text has excessive repetition.""" + words = text.split() + if len(words) < 10: + return False + + for i in range(len(words) - max_repeat): + word = words[i] + if all(words[i + j] == word for j in range(max_repeat)): + return True + + new_text = text[-100:] if len(text) > 100 else text + if len(new_text) > 20: + char_counts = {} + for c in new_text: + char_counts[c] = char_counts.get(c, 0) + 1 + max_char_ratio = max(char_counts.values()) / len(new_text) + if max_char_ratio > 0.5: + return True + + return False + + +def test_performance_ttft(compiled_model, tokenizer): + """Test Time To First Token (TTFT) performance.""" + import time + + prompt = "Hello, how are you?" 
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids

    # Warmup: a few unmeasured prefill passes so compilation/caching effects
    # do not pollute the timed runs.
    for _ in range(3):
        seq_len = input_ids.shape[1]
        position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1)
        with torch.no_grad():
            _ = compiled_model(input_ids, position_ids=position_ids)

    # Measure TTFT: latency of a single prefill forward pass, averaged.
    times = []
    for _ in range(10):
        seq_len = input_ids.shape[1]
        position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1)

        start = time.perf_counter()
        with torch.no_grad():
            _ = compiled_model(input_ids, position_ids=position_ids)
        end = time.perf_counter()

        times.append((end - start) * 1000)  # milliseconds

    # Informational only — no assertion, so this never fails on slow hosts.
    avg_ttft = sum(times) / len(times)
    print(f"TTFT: {avg_ttft:.2f}ms")


def test_performance_throughput(compiled_model, tokenizer):
    """Test token generation throughput.

    NOTE(review): generate_with_neuron_model re-runs the full sequence each
    step, so this measures end-to-end loop throughput (prefill repeated per
    token), not pure incremental-decode throughput.
    """
    import time

    prompt = "Hello"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids
    num_tokens = 50

    # Warmup
    _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=5)

    # Measure throughput
    start = time.perf_counter()
    _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=num_tokens)
    end = time.perf_counter()

    total_time = end - start
    throughput = num_tokens / total_time
    print(f"Throughput: {throughput:.2f} tok/s")


# Standalone driver: runs the functional tests in order without pytest,
# constructing the model and tokenizer directly instead of via fixtures.
if __name__ == "__main__":
    print("=" * 80)
    print("Trinity (AfmoeForCausalLM) Integration Tests")
    print("=" * 80)
    print(f"Model path: {MODEL_PATH}")
    print(f"Compiled path: {COMPILED_MODEL_PATH}")

    print(f"\nLoading compiled model from {COMPILED_MODEL_PATH}...")
    model, neuron_config = create_model_for_inference(COMPILED_MODEL_PATH, MODEL_PATH)
    model.load(COMPILED_MODEL_PATH)
    print("Model loaded")

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH, padding_side="right", trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("\n" + "=" * 80)
    print("Running Tests")
    print("=" * 80)

    print("\n1. Smoke Test (Model Loading)...")
    test_model_loads(model)

    print("\n2. Generation Test...")
    test_model_generates(model, tokenizer)

    print("\n3. Coherence Test...")
    test_output_coherence(model, tokenizer)

    print("\n4. Top Token Validation...")
    test_top_token_valid(model, tokenizer)

    print("\n" + "=" * 80)
    print("All tests passed!")
    print("=" * 80)
diff --git a/contrib/models/Trinity/test/unit/__init__.py b/contrib/models/Trinity/test/unit/__init__.py
new file mode 100644
index 00000000..e69de29b

From 466631bb3e62c06da98c9e8484c2f3e92c816cc4 Mon Sep 17 00:00:00 2001
From: Jim Burtoft
Date: Sat, 28 Feb 2026 11:52:53 -0500
Subject: [PATCH 2/2] Fix mixed attention KV cache sizing and add max sequence
 length results

Add layer_to_cache_size_mapping in setup_attr_for_model() to provide
per-layer KV cache sizes for mixed attention models. Without this,
KVCacheManager sizes all layers to sliding_window, causing a tensor
shape mismatch in compute_for_token_gen when seq_len > sliding_window.

Update README with validated max sequence lengths:
- Nano TP=2: 40960, TP=4: 49152 (trn2.3xlarge)
- Mini TP=4: 32768 (trn2.3xlarge)
- Large TP=64: 30720 (trn2.48xlarge)

All verified with actual token generation at max seq_len.
--- contrib/models/Trinity/README.md | 46 +++++++++++++------ .../models/Trinity/src/modeling_trinity.py | 14 ++++++ 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/contrib/models/Trinity/README.md b/contrib/models/Trinity/README.md index 6025a050..4c903e52 100644 --- a/contrib/models/Trinity/README.md +++ b/contrib/models/Trinity/README.md @@ -138,7 +138,7 @@ compiled_path = "/path/to/compiled-nano/" neuron_config = MoENeuronConfig( tp_degree=2, # Nano is small enough for TP=2 batch_size=1, - seq_len=2048, + seq_len=2048, # Max tested: 40960 (TP=2), 49152 (TP=4) torch_dtype=torch.bfloat16, ) @@ -172,7 +172,7 @@ compiled_path = "/path/to/compiled-mini/" neuron_config = MoENeuronConfig( tp_degree=4, batch_size=1, - seq_len=2048, + seq_len=2048, # Max tested: 32768 torch_dtype=torch.bfloat16, ) @@ -206,7 +206,7 @@ compiled_path = "/path/to/compiled-large/" neuron_config = MoENeuronConfig( tp_degree=64, batch_size=1, - seq_len=4096, + seq_len=4096, # Max tested: 30720 torch_dtype=torch.bfloat16, ) @@ -241,17 +241,37 @@ tokenizer = AutoTokenizer.from_pretrained( 7. **Gate padding at high TP** -- When `num_attention_heads` is not evenly divisible by `tp_degree` (e.g., Large at TP=64: 48/64), gate weights are padded with interleaved layout matching the Q projection. This is handled automatically during weight conversion. +8. **Mixed attention KV cache sizing** -- Trinity uses mixed attention (alternating sliding window and global attention layers). When `seq_len > sliding_window`, global attention layers need full `seq_len`-sized KV caches while sliding layers only need `sliding_window`-sized caches. The `layer_to_cache_size_mapping` in `setup_attr_for_model()` provides per-layer cache sizes to the framework. Without this, NxDI's `KVCacheManager` sizes ALL layers to `sliding_window`, causing a tensor shape mismatch in `compute_for_token_gen` where `prior_scores` (from undersized KV cache) doesn't match `attention_mask` (sized to `seq_len`). 
This fix is required for any `seq_len` above the model's sliding window. + +## Maximum Sequence Length + +Validated with token generation (5 tokens per prompt) at each max seq_len: + +| Model | Instance | TP | Max seq_len | Compile | Load | Gen Latency | +|-------|----------|-----|------------|---------|------|-------------| +| Nano | trn2.3xlarge | 2 | **40,960** | 1.5 min | 3.2 min | 2.4s/tok | +| Nano | trn2.3xlarge | 4 | **49,152** | 1.4 min | 1.4 min | 2.4s/tok | +| Mini | trn2.3xlarge | 4 | **32,768** | 0.9 min | 7.7 min | 2.4s/tok | +| Large | trn2.48xlarge | 64 | **30,720** | 1.6 min | 16.5 min | 2.9s/tok | + +Compile times above are for cache-hit runs. First compilation at each seq_len takes 5-25 min. + +Higher TP gives more headroom for KV cache (Nano TP=4 fits 49K vs 41K at TP=2). The failure mode at the limit is compilation timeout, not OOM. + +**Important:** The `layer_to_cache_size_mapping` fix in `modeling_trinity.py` is **required** for `seq_len > sliding_window` (2048 for Nano/Mini, 4096 for Large). Without it, token generation fails with a tensor shape mismatch in `compute_for_token_gen`. See Caveats section. 
+ ## Compatibility Matrix -| Model | Instance | TP | LNC | Status | -|-------|----------|-----|-----|--------| -| Nano | inf2.xlarge | 1 | N/A | FAIL (16GB system RAM OOM) | -| Nano | inf2.8xlarge | 1 | N/A | Validated | -| Nano | trn2.3xlarge | 2 | 2 | Validated | -| Mini | inf2.8xlarge | -- | -- | Does NOT fit | -| Mini | trn2.3xlarge | 4 | 2 | Validated | -| Large | trn2.48xlarge | 32 | 2 | FAIL (HBM OOM per NC) | -| Large | trn2.48xlarge | 64 | 2 | Validated | +| Model | Instance | TP | LNC | Max seq_len | Status | +|-------|----------|-----|-----|------------|--------| +| Nano | inf2.xlarge | 1 | N/A | -- | FAIL (16GB system RAM OOM) | +| Nano | inf2.8xlarge | 1 | N/A | -- | Validated (not seq_len tested) | +| Nano | trn2.3xlarge | 2 | 2 | 40,960 | Validated | +| Nano | trn2.3xlarge | 4 | 2 | 49,152 | Validated | +| Mini | inf2.8xlarge | -- | -- | -- | Does NOT fit | +| Mini | trn2.3xlarge | 4 | 2 | 32,768 | Validated | +| Large | trn2.48xlarge | 32 | 2 | -- | FAIL (HBM OOM per NC) | +| Large | trn2.48xlarge | 64 | 2 | 30,720 | Validated | ### Minimum Requirements by Model Size @@ -336,4 +356,4 @@ The NxDI framework uses several NKI (Neuron Kernel Interface) kernels during Tri Jim Burtoft -**Last Updated:** 2026-02-27 +**Last Updated:** 2026-02-28 diff --git a/contrib/models/Trinity/src/modeling_trinity.py b/contrib/models/Trinity/src/modeling_trinity.py index 1706da94..90d19198 100644 --- a/contrib/models/Trinity/src/modeling_trinity.py +++ b/contrib/models/Trinity/src/modeling_trinity.py @@ -1016,6 +1016,20 @@ def setup_attr_for_model(self, config: TrinityInferenceConfig): self.sliding_window = getattr(config, "sliding_window", None) self.has_mixed_attn = True + # Per-layer KV cache sizing for mixed attention. + # Global attention layers need the full seq_len KV cache; sliding layers + # only need sliding_window. 
Without this, the KV cache manager sizes ALL + # layers to sliding_window, which causes a shape mismatch at token-gen time + # when seq_len > sliding_window (the global attention mask is sized to + # seq_len but K_prior is sized to sliding_window). + n_positions = config.neuron_config.n_positions + sw = self.sliding_window or n_positions + if sw < n_positions: + self.layer_to_cache_size_mapping = [ + sw if lt == "sliding_attention" else n_positions + for lt in config.layer_types + ] + def init_model(self, config: TrinityInferenceConfig): self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size