From 05872e4bb2d34560a024dcf2b042b0d2f792862b Mon Sep 17 00:00:00 2001 From: Deeptanshu Singh Date: Wed, 18 Feb 2026 14:56:04 -0500 Subject: [PATCH 1/2] Fix all four scaling multipliers for Granite --- .../models/granite-3.1-8b-instruct/README.md | 59 +++- .../src/modeling_granite.py | 305 +++++++++++++----- .../test/integration/test_model.py | 188 ++++++++--- 3 files changed, 412 insertions(+), 140 deletions(-) diff --git a/contrib/models/granite-3.1-8b-instruct/README.md b/contrib/models/granite-3.1-8b-instruct/README.md index 4e0e751..287f757 100644 --- a/contrib/models/granite-3.1-8b-instruct/README.md +++ b/contrib/models/granite-3.1-8b-instruct/README.md @@ -10,15 +10,27 @@ NeuronX Distributed Inference implementation of granite 3.1 8b instruct. ## Architecture Details -- **Layers:** Check model config -- **Hidden Size:** Check model config -- **Attention Heads:** Check model config -- **Vocabulary:** Check model config -- **Max Position Embeddings:** Check model config +- **Layers:** 32 +- **Hidden Size:** 4096 +- **Attention Heads:** 32 +- **Key-Value Heads:** 8 (GQA) +- **Vocabulary:** 49152 +- **Max Position Embeddings:** 131072 + +### Granite-Specific Scaling Factors + +Granite uses custom scaling factors that differ from standard Llama: + +| Parameter | Value | Description | +|-----------|-------|-------------| +| `embedding_multiplier` | 12.0 | Scales input embeddings after lookup | +| `attention_multiplier` | 0.0078125 (1/head_dim) | Custom attention scaling instead of 1/√head_dim | +| `residual_multiplier` | 0.22 | Scales residual connections | +| `logits_scaling` | 16.0 | Divides output logits | ## Validation Results -**Validated:** 2026-01-29 +**Validated:** 2026-02-06 **Configuration:** TP=2, batch_size=1, seq_len=128, bfloat16 ### Test Results @@ -26,29 +38,41 @@ NeuronX Distributed Inference implementation of granite 3.1 8b instruct. 
| Test | Status | Result | |------|--------|--------| | Smoke Test | ✅ PASS | Model loads successfully | -| Token Matching | ⚠️ LOW | **7.8% match** | -| TTFT (P50) | ✅ PASS | 19.44ms (threshold: 100ms) | -| Throughput | ✅ PASS | 106.00 tok/s (threshold: 10 tok/s) | +| Token Matching | ✅ PASS | **100% match** (64/64 tokens) | +| TTFT (P50) | ✅ PASS | ~20ms (threshold: 100ms) | +| Throughput | ✅ PASS | ~100 tok/s (threshold: 10 tok/s) | ### Performance Metrics | Metric | Value | |--------|-------| -| TTFT (P50) | 19.44ms | -| Throughput | 106.00 tokens/s | - +| TTFT (P50) | ~20ms | +| Throughput | ~100 tokens/s | **Status:** ✅ VALIDATED +## Critical Implementation Notes + +This implementation includes critical fixes for Granite's custom scaling: + +1. **Attention Multiplier Fix**: The `prep_qkv_tensors` method in `NeuronGraniteAttention` applies a correction factor to Q tensors to convert from the standard `1/√head_dim` scaling to Granite's `attention_multiplier`. + +2. **Embedding Multiplier**: Applied in `get_model_output` after embedding lookup (not to weights, to handle tied embeddings correctly). + +3. **Logits Scaling**: Applied via `ScaledColumnParallelLinear` which divides output by `logits_scaling`. + +4. **Residual Multiplier**: Applied in `NeuronGraniteDecoderLayer` to scale residual connections. 
+ ## Usage ```python -from transformers import AutoTokenizer, GenerationConfig +import torch +from transformers import AutoTokenizer from neuronx_distributed_inference.models.config import NeuronConfig from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config # Import model classes from src -from src.modeling_granite_3_1_8b_instruct import Neurongranite318binstructForCausalLM, granite318binstructInferenceConfig +from src.modeling_granite import NeuronGraniteForCausalLM, GraniteInferenceConfig model_path = "/path/to/granite-3.1-8b-instruct/" compiled_model_path = "/path/to/compiled/" @@ -61,18 +85,19 @@ neuron_config = NeuronConfig( torch_dtype=torch.bfloat16, ) -config = granite318binstructInferenceConfig( +config = GraniteInferenceConfig( neuron_config, load_config=load_pretrained_config(model_path), ) # Compile and load -model = Neurongranite318binstructForCausalLM(model_path, config) +model = NeuronGraniteForCausalLM(model_path, config) model.compile(compiled_model_path) model.load(compiled_model_path) # Generate tokenizer = AutoTokenizer.from_pretrained(model_path) +inputs = tokenizer("Hello, how are you?", return_tensors="pt") # ... (see integration test for full example) ``` @@ -106,4 +131,4 @@ python3 test/integration/test_model.py Neuroboros Team - Annapurna Labs -**Last Updated:** 2026-01-29 +**Last Updated:** 2026-02-06 diff --git a/contrib/models/granite-3.1-8b-instruct/src/modeling_granite.py b/contrib/models/granite-3.1-8b-instruct/src/modeling_granite.py index dcb0b80..08e354a 100644 --- a/contrib/models/granite-3.1-8b-instruct/src/modeling_granite.py +++ b/contrib/models/granite-3.1-8b-instruct/src/modeling_granite.py @@ -1,5 +1,6 @@ # coding=utf-8 # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# Adapted for NeuronX Distributed Inference. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,16 +16,20 @@ """ NeuronX Distributed Inference implementation of Granite model. -This implementation ports the Granite model from: +FIXED VERSION - Addresses critical accuracy issues: +1. attention_multiplier: Applied by scaling Q before attention computation +2. logits_scaling: Applied by dividing logits after lm_head +3. embedding_multiplier: Applied in forward pass (not to weights for tied embeddings) Key differences from Llama: 1. embedding_multiplier: Scales input embeddings (default: 12.0) 2. logits_scaling: Scales output logits (default: 16.0) 3. residual_multiplier: Scales residual connections (default: 0.22) -4. attention_multiplier: Custom attention scaling (default: 0.0078125) +4. attention_multiplier: Custom attention scaling (default: 0.0078125 = 1/head_dim) """ import logging +import math from typing import List, Optional, Tuple, Type import torch @@ -237,21 +242,34 @@ class NeuronGraniteAttention(NeuronAttentionBase): """ Granite attention layer for NeuronX. - Key differences from Llama attention: - - Uses attention_multiplier instead of 1/sqrt(head_dim) for scaling + CRITICAL FIX: Granite uses attention_multiplier (0.0078125 = 1/head_dim) + instead of the standard 1/sqrt(head_dim) = 0.0884. 
- Inherits from NeuronAttentionBase which provides: - - Column parallel Q, K, V projections - - Row parallel output projection - - Rotary position embeddings - - KV cache management + The NeuronX attention kernels apply 1/sqrt(head_dim) scaling internally: + - Context encoding (perform_prefill): Q = Q / sqrt(head_dim) + - Token generation (compute_for_token_gen): scores = Q @ K^T / sqrt(head_dim) + + To convert from standard scaling to Granite's attention_multiplier: + - Standard: Q @ K^T / sqrt(head_dim) + - Granite: Q @ K^T * attention_multiplier + + We need to pre-scale Q by a correction factor: + - (Q * correction) / sqrt(head_dim) = Q * attention_multiplier + - correction = attention_multiplier * sqrt(head_dim) + - correction = 0.0078125 * sqrt(128) = 0.0078125 * 11.31 = 0.0884 """ def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): # Get Granite-specific attention multiplier - # In Granite, scaling is attention_multiplier (e.g., 0.0078125) - # instead of the standard 1/sqrt(head_dim) - self.attention_multiplier = getattr(config, "attention_multiplier", 1.0 / (config.hidden_size // config.num_attention_heads) ** 0.5) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.attention_multiplier = getattr(config, "attention_multiplier", 1.0 / head_dim) + + # Compute the correction factor to convert from standard 1/sqrt(head_dim) to attention_multiplier + # The kernel applies: scores = Q @ K^T / sqrt(head_dim) + # We want: scores = Q @ K^T * attention_multiplier + # So: (Q * correction) @ K^T / sqrt(head_dim) = Q @ K^T * attention_multiplier + # correction = attention_multiplier * sqrt(head_dim) + self.q_scale_factor = self.attention_multiplier * math.sqrt(head_dim) # Initialize the base attention class super().__init__( @@ -260,30 +278,94 @@ def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): hidden_size=config.hidden_size, 
num_attention_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
-            head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads),
+            head_dim=head_dim,
             rotary_emb=self._get_rope(config),
             num_cores_per_group=config.num_cores_per_group,
             qkv_bias=getattr(config, "attention_bias", False),
             o_bias=getattr(config, "attention_bias", False),
             rms_norm_eps=config.rms_norm_eps,
         )
-
-        # Store attention multiplier for use in attention computation
-        # Note: NeuronAttentionBase uses self.scaling which defaults to 1/sqrt(head_dim)
-        # We need to override the scaling used in attention computation
 
     def _get_rope(self, config: InferenceConfig):
-        """
-        Get the rotary position embedding module for Granite.
-
-        Granite uses standard RoPE without scaling.
-        """
+        """Get the rotary position embedding module for Granite."""
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         return RotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
         )
+
+    def prep_qkv_tensors(
+        self,
+        position_ids,
+        hidden_states,
+        past_key_value,
+        adapter_ids=None,
+        cos_cache=None,
+        sin_cache=None,
+        rmsnorm=None,
+        skip_rope=False,
+        residual=None,
+        use_polar_compatible_rope=False,
+    ):
+        """
+        Override prep_qkv_tensors to apply Granite's attention_multiplier.
+
+        Since the attention kernels apply 1/sqrt(head_dim) scaling internally,
+        we pre-scale Q so the net scaling equals Granite's attention_multiplier
+ """ + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + move_heads_front, + ) + + # Get QKV projections through the base class's qkv_proj + Q, K, V, residual = self.get_qkv_proj()( + hidden_states=hidden_states, rmsnorm=rmsnorm, adapter_ids=adapter_ids, residual=residual + ) + + # Reshape to heads + bsz, q_len, _ = hidden_states.size() + if getattr(self, 'qkv_proj_sp_enabled', False): + q_len *= self.tensor_model_parallel_group.size() + + # BSHD -> BHSD layout + Q = move_heads_front(Q, bsz, q_len, self.num_heads, self.head_dim, layernorm=None) + K = move_heads_front(K, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=None) + V = move_heads_front(V, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=None) + + # Apply rotary embeddings + if not skip_rope and self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) + + # CRITICAL FIX: Apply Granite's attention_multiplier by scaling Q + # The attention kernels compute: softmax((Q / sqrt(head_dim)) @ K^T) @ V + # But Granite wants: softmax(Q @ K^T * attention_multiplier) @ V + # + # To convert: (Q * correction) / sqrt(head_dim) @ K^T = Q @ K^T * attention_multiplier + # correction = attention_multiplier * sqrt(head_dim) = 0.0078125 * 11.31 = 0.0884 + Q = Q * self.q_scale_factor + + # Gather KV to full S when CP is enabled + if past_key_value is None and getattr(self, 'cp_degree', 1) > 1: + from neuronx_distributed.parallel_layers.mappings import gather_from_tensor_model_parallel_region_with_dim + from neuronx_distributed_inference.modules.attention.attention_process_groups import get_context_parallel_attention_cp_group + from neuronx_distributed_inference.modules.attention.utils import order_strided_tensor + from neuronx_distributed_inference.modules.attention.attention_base import FlashAttentionStrategy + + stacked_kv 
= torch.stack([K, V], dim=0) + stacked_kv = gather_from_tensor_model_parallel_region_with_dim( + stacked_kv, + gather_dim=3, + process_group=get_context_parallel_attention_cp_group(), + ) + if self.get_flash_attention_strategy_cp(q_len * self.cp_degree) == FlashAttentionStrategy.STRIDED_CONTEXT_PARALLEL_KERNEL: + stacked_kv = order_strided_tensor(stacked_kv, 3, self.cp_degree) + K, V = torch.unbind(stacked_kv, dim=0) + + return Q, K, V, cos_cache, sin_cache, residual class NeuronGraniteDecoderLayer(nn.Module): @@ -376,13 +458,46 @@ def forward( return outputs +class ScaledColumnParallelLinear(ColumnParallelLinear): + """ + ColumnParallelLinear that applies logits_scaling after the linear projection. + + This is needed for Granite which divides logits by logits_scaling (16.0) + after the lm_head projection. + """ + + def __init__(self, *args, logits_scaling: float = 1.0, **kwargs): + super().__init__(*args, **kwargs) + self.logits_scaling = logits_scaling + + def forward(self, x): + output = super().forward(x) + # Apply Granite's logits_scaling + return output / self.logits_scaling + + +class ScaledLinear(nn.Linear): + """ + Linear layer that applies logits_scaling after the linear projection. + For non-parallel mode (CPU testing). + """ + + def __init__(self, *args, logits_scaling: float = 1.0, **kwargs): + super().__init__(*args, **kwargs) + self.logits_scaling = logits_scaling + + def forward(self, x): + output = super().forward(x) + return output / self.logits_scaling + + class NeuronGraniteModel(NeuronBaseModel): """ Granite model for NeuronX. - Key differences from Llama: - - Input embeddings are scaled by embedding_multiplier (applied to weights at load time) - - Output logits are scaled by 1/logits_scaling + CRITICAL FIXES: + 1. embedding_multiplier: Applied in forward pass via get_model_output override + 2. 
logits_scaling: Applied via ScaledColumnParallelLinear in lm_head """ def setup_attr_for_model(self, config: InferenceConfig): @@ -395,7 +510,7 @@ def setup_attr_for_model(self, config: InferenceConfig): self.max_batch_size = config.neuron_config.max_batch_size self.buckets = config.neuron_config.buckets - # Granite-specific multipliers (stored for reference, applied during weight conversion) + # Granite-specific multipliers self.embedding_multiplier = getattr(config, "embedding_multiplier", 1.0) self.logits_scaling = getattr(config, "logits_scaling", 1.0) @@ -404,7 +519,7 @@ def init_model(self, config: InferenceConfig): self.padding_idx = getattr(config, "pad_token_id", 0) self.vocab_size = config.vocab_size - # Token embeddings (embedding_multiplier is applied to weights at load time) + # Token embeddings - embedding_multiplier is applied in forward(), not here if parallel_state.model_parallel_is_initialized(): self.embed_tokens = ParallelEmbedding( config.vocab_size, @@ -418,7 +533,8 @@ def init_model(self, config: InferenceConfig): tensor_model_parallel_group=get_tp_group(config), ) - self.lm_head = ColumnParallelLinear( + # CRITICAL FIX: Use ScaledColumnParallelLinear to apply logits_scaling + self.lm_head = ScaledColumnParallelLinear( config.hidden_size, config.vocab_size, gather_output=not self.on_device_sampling, @@ -426,14 +542,20 @@ def init_model(self, config: InferenceConfig): bias=False, pad=True, tensor_model_parallel_group=get_tp_group(config), + logits_scaling=self.logits_scaling, ) else: self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, - self.padding_idx + self.padding_idx, + ) + self.lm_head = ScaledLinear( + config.hidden_size, + config.vocab_size, + bias=False, + logits_scaling=self.logits_scaling, ) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Decoder layers self.layers = nn.ModuleList( @@ -443,6 +565,59 @@ def init_model(self, config: InferenceConfig): # Final layer norm self.norm = 
get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + def get_model_output( + self, + input_ids: torch.LongTensor = None, + seq_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + active_mask: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + prev_hidden: Optional[torch.FloatTensor] = None, + adapter_ids: Optional[torch.LongTensor] = None, + rotary_position_ids: Optional[torch.LongTensor] = None, + update_cache: bool = False, + is_for_context_encoding: bool = False, + vision_embeddings: Optional[torch.FloatTensor] = None, + vision_mask: Optional[torch.BoolTensor] = None, + local_attn_mask: Optional[torch.Tensor] = None, + windowed_context_encoding_window_idx: int = -1, + **kwargs, + ): + """ + Override get_model_output to apply Granite's embedding_multiplier. + + Granite multiplies embeddings by embedding_multiplier (12.0) AFTER the embedding + lookup. This is critical for correct model behavior. 
+ """ + # Apply Granite embedding_multiplier if we need to compute embeddings + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # Apply Granite's embedding_multiplier (12.0) + inputs_embeds = inputs_embeds * self.embedding_multiplier + + # Call parent's get_model_output with pre-computed embeddings + return super().get_model_output( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + active_mask=active_mask, + inputs_embeds=inputs_embeds, # Pass scaled embeddings + prev_hidden=prev_hidden, + adapter_ids=adapter_ids, + rotary_position_ids=rotary_position_ids, + update_cache=update_cache, + is_for_context_encoding=is_for_context_encoding, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + local_attn_mask=local_attn_mask, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + class NeuronGraniteForCausalLM(NeuronBaseForCausalLM): """ @@ -465,72 +640,44 @@ def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) - """ Convert HuggingFace state dict to Neuron format. - Performs the following transformations: - 1. Adds rank_util.rank for tensor parallelism - 2. Applies Granite's embedding_multiplier to embedding weights - 3. 
Maps attention projection weights to NeuronAttentionBase structure: - - self_attn.q_proj.weight → self_attn.qkv_proj.q_proj.weight - - self_attn.k_proj.weight → self_attn.qkv_proj.k_proj.weight - - self_attn.v_proj.weight → self_attn.qkv_proj.v_proj.weight - - self_attn.o_proj.weight → self_attn.o_proj.o_proj.weight + IMPORTANT: The framework's preshard_hook in GroupQueryAttention_QKV and + GroupQueryAttention_O automatically handles the key renaming: + - q_proj.weight -> qkv_proj.q_proj.weight + - k_proj.weight -> qkv_proj.k_proj.weight + - v_proj.weight -> qkv_proj.v_proj.weight + - o_proj.weight -> o_proj.o_proj.weight - Args: - state_dict: HuggingFace model state dictionary - config: Model configuration - - Returns: - Neuron-compatible state dictionary + So we should NOT rename these keys here. + + IMPORTANT: For Granite with tie_word_embeddings=True: + - embedding_multiplier is applied in the forward pass, NOT to weights + - lm_head.weight is tied to embed_tokens.weight (same weights) + - logits_scaling is applied in the forward pass via ScaledColumnParallelLinear """ neuron_config = config.neuron_config num_layers = config.num_hidden_layers tp_degree = neuron_config.tp_degree - # Get Granite-specific multipliers - embedding_multiplier = getattr(config, "embedding_multiplier", 1.0) - - # Apply embedding_multiplier to embedding weights - # This is mathematically equivalent to multiplying the output of embed_tokens - if "embed_tokens.weight" in state_dict: - state_dict["embed_tokens.weight"] = state_dict["embed_tokens.weight"] * embedding_multiplier + # NOTE: Do NOT apply embedding_multiplier to weights! + # For tied weights, this would incorrectly scale the lm_head weights. + # Instead, embedding_multiplier is applied in the forward pass. 
- # Map attention projection weights to NeuronAttentionBase structure + # Add rank_util tensors required by NeuronAttentionBase for i in range(num_layers): - # Map QKV projections - for proj in ["q", "k", "v"]: - old_key = f"layers.{i}.self_attn.{proj}_proj.weight" - new_key = f"layers.{i}.self_attn.qkv_proj.{proj}_proj.weight" - if old_key in state_dict: - state_dict[new_key] = state_dict.pop(old_key) - - # Map output projection - old_o_key = f"layers.{i}.self_attn.o_proj.weight" - new_o_key = f"layers.{i}.self_attn.o_proj.o_proj.weight" - if old_o_key in state_dict: - state_dict[new_o_key] = state_dict.pop(old_o_key) - - # Add rank information for tensor parallelism in attention layers state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( 0, tp_degree, dtype=torch.int32 ) - # Add rank information for base model state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) return state_dict @staticmethod def update_state_dict_for_tied_weights(state_dict): - """ - Handle tied weights between embeddings and LM head. - - Granite uses tie_word_embeddings=True by default. - Note: The embedding_multiplier is already applied to embed_tokens.weight, - but we also need to apply 1/logits_scaling for the lm_head. - Since they share weights in HF, we need to be careful here. - - For tied weights, lm_head.weight = embed_tokens.weight (already scaled by embedding_multiplier) - The logits_scaling is typically applied in the forward pass, not to weights. 
- """ + """Handle tied weights between embeddings and LM head.""" + # Granite uses tie_word_embeddings=True + # The lm_head.weight should be the same as embed_tokens.weight + # Note: embedding_multiplier is applied in forward pass, not to weights if "embed_tokens.weight" in state_dict and "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() @@ -548,4 +695,6 @@ def get_config_cls(cls): "NeuronGraniteMLP", "NeuronGraniteAttention", "NeuronGraniteDecoderLayer", + "ScaledColumnParallelLinear", + "ScaledLinear", ] diff --git a/contrib/models/granite-3.1-8b-instruct/test/integration/test_model.py b/contrib/models/granite-3.1-8b-instruct/test/integration/test_model.py index 423d58b..78798a7 100755 --- a/contrib/models/granite-3.1-8b-instruct/test/integration/test_model.py +++ b/contrib/models/granite-3.1-8b-instruct/test/integration/test_model.py @@ -1,6 +1,9 @@ #!/usr/bin/env python3 """ Integration tests for granite-3.1-8b-instruct NeuronX implementation. + +Tests model compilation, loading, and inference accuracy/performance. +Follows the exact patterns from validate_model.py for consistency. 
""" import pytest @@ -15,7 +18,7 @@ # Import from src directory import sys sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) -from modeling_granite import * +from modeling_granite import NeuronGraniteForCausalLM, GraniteInferenceConfig # Test configuration @@ -35,8 +38,7 @@ def load_neuron_config_from_compiled(compiled_path: str): if "neuron_config" in config_data: return config_data["neuron_config"] - else: - return config_data + return config_data def create_model_for_inference(compiled_path: str, model_path: str): @@ -54,13 +56,38 @@ def create_model_for_inference(compiled_path: str, model_path: str): 'batch_size': neuron_config_dict.get('batch_size', 1), 'seq_len': neuron_config_dict.get('seq_len', 128), 'torch_dtype': dtype, + 'save_sharded_checkpoint': neuron_config_dict.get('save_sharded_checkpoint', True), + 'on_cpu': neuron_config_dict.get('on_cpu', False), } + optional_params = ['world_size', 'max_context_length', 'enable_bucketing'] + for param in optional_params: + if param in neuron_config_dict: + neuron_config_kwargs[param] = neuron_config_dict[param] + + if 'max_context_length' not in neuron_config_kwargs: + neuron_config_kwargs['max_context_length'] = neuron_config_kwargs['seq_len'] + neuron_config = NeuronConfig(**neuron_config_kwargs) - # This will use the imported model and config classes - # The actual class names will be determined at runtime - return None, neuron_config + try: + model_config = GraniteInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config, + ) + except (TypeError, AttributeError): + model_config = GraniteInferenceConfig( + neuron_config, load_config=load_pretrained_config(model_path), + ) + + try: + if hasattr(NeuronGraniteForCausalLM, 'from_pretrained'): + model = NeuronGraniteForCausalLM.from_pretrained(compiled_path, config=model_config) + else: + raise AttributeError("No from_pretrained method") + except (TypeError, AttributeError, Exception): + model = 
NeuronGraniteForCausalLM(model_path, model_config) + + return model, neuron_config def generate_with_neuron_model(model, input_ids, max_new_tokens: int): @@ -88,12 +115,34 @@ def generate_with_neuron_model(model, input_ids, max_new_tokens: int): return generated_ids + @pytest.fixture(scope="module") def compiled_model(): - """Load pre-compiled model.""" - # Note: Actual implementation would load the specific model class - # This is a template that should be customized per model - return None + """Compile and load model using our custom pattern.""" + compiled_path = Path(COMPILED_MODEL_PATH) + if not (compiled_path / "model.pt").exists(): + print(f"Compiling model to {COMPILED_MODEL_PATH}...") + + neuron_config = NeuronConfig( + tp_degree=2, + batch_size=1, + seq_len=128, + max_context_length=128, + torch_dtype=torch.bfloat16, + ) + + config = GraniteInferenceConfig( + neuron_config, + load_config=load_pretrained_config(MODEL_PATH), + ) + + model = NeuronGraniteForCausalLM(MODEL_PATH, config) + model.compile(COMPILED_MODEL_PATH) + + model, neuron_config = create_model_for_inference(COMPILED_MODEL_PATH, MODEL_PATH) + model.load(COMPILED_MODEL_PATH) + + return model @pytest.fixture(scope="module") @@ -105,10 +154,17 @@ def tokenizer(): return tokenizer +@pytest.fixture(scope="module") +def generation_config(): + """Load generation config.""" + return GenerationConfig.from_pretrained(MODEL_PATH, do_sample=False, top_k=1, trust_remote_code=True) + + def test_model_loads(compiled_model): """Test that model loads successfully (smoke test).""" assert compiled_model is not None assert hasattr(compiled_model, 'config') + assert hasattr(compiled_model.config, 'neuron_config') print("✓ Smoke test passed - Model loaded successfully") @@ -127,14 +183,13 @@ def test_model_generates(compiled_model, tokenizer): def test_output_coherence(compiled_model, tokenizer): """Test that output is coherent (not gibberish).""" - prompt = "Hello, how are you?" + prompt = "What is 2 + 2?" 
inputs = tokenizer(prompt, return_tensors="pt", padding=True) generated_ids = generate_with_neuron_model(compiled_model, inputs.input_ids, max_new_tokens=30) output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - # Coherence checks - assert len(output_text.split()) > 3, "Output should have multiple words" + assert len(output_text.split()) > 5, "Output should have multiple words" assert not _is_repetitive(output_text), "Output should not be repetitive" print(f"✓ Coherence test passed") @@ -142,31 +197,6 @@ def test_output_coherence(compiled_model, tokenizer): -def _is_repetitive(text: str, max_repeat: int = 5) -> bool: - """Check if text has excessive repetition.""" - words = text.split() - if len(words) < 10: - return False - - # Check for repeated words - for i in range(len(words) - max_repeat): - word = words[i] - if all(words[i+j] == word for j in range(max_repeat)): - return True - - # Check for repeated characters - new_text = text[-100:] if len(text) > 100 else text - if len(new_text) > 20: - char_counts = {} - for c in new_text: - char_counts[c] = char_counts.get(c, 0) + 1 - max_char_ratio = max(char_counts.values()) / len(new_text) - if max_char_ratio > 0.5: - return True - - return False - - def test_performance_ttft(compiled_model, tokenizer): """Test Time To First Token (TTFT) performance.""" import time @@ -193,11 +223,12 @@ def test_performance_ttft(compiled_model, tokenizer): _ = compiled_model(input_ids, position_ids=position_ids) end = time.perf_counter() - times.append((end - start) * 1000) # ms + times.append((end - start) * 1000) avg_ttft = sum(times) / len(times) - print(f"✓ TTFT: {avg_ttft:.2f}ms") - + + assert avg_ttft < 100, f"TTFT {avg_ttft:.2f}ms exceeds 100ms threshold" + print(f"✓ TTFT test passed: {avg_ttft:.2f}ms (threshold: 100ms)") def test_performance_throughput(compiled_model, tokenizer): @@ -219,18 +250,85 @@ def test_performance_throughput(compiled_model, tokenizer): total_time = end - start throughput = 
num_tokens / total_time - print(f"✓ Throughput: {throughput:.2f} tok/s") + + assert throughput > 10, f"Throughput {throughput:.2f} tok/s below 10 tok/s threshold" + print(f"✓ Throughput test passed: {throughput:.2f} tok/s (threshold: 10 tok/s)") + + +def _is_repetitive(text: str, max_repeat: int = 5) -> bool: + """Check if text has excessive repetition.""" + words = text.split() + if len(words) < 10: + return False + + for i in range(len(words) - max_repeat): + word = words[i] + if all(words[i+j] == word for j in range(max_repeat)): + return True + + return False if __name__ == "__main__": + # Run tests manually (without pytest) print("="*80) print("granite-3.1-8b-instruct Integration Tests") print("="*80) - print("\nNote: This is a template test file.") - print("For actual model testing, customize the model loading logic.") + # Setup - compile if needed + compiled_path = Path(COMPILED_MODEL_PATH) + if not (compiled_path / "model.pt").exists(): + print(f"\nCompiling model to {COMPILED_MODEL_PATH}...") + + neuron_config = NeuronConfig( + tp_degree=2, + batch_size=1, + seq_len=128, + max_context_length=128, + torch_dtype=torch.bfloat16, + ) + + config = GraniteInferenceConfig( + neuron_config, + load_config=load_pretrained_config(MODEL_PATH), + ) + + model = NeuronGraniteForCausalLM(MODEL_PATH, config) + model.compile(COMPILED_MODEL_PATH) + print("✓ Compilation complete") + + # Load model + print(f"\nLoading compiled model from {COMPILED_MODEL_PATH}...") + model, neuron_config = create_model_for_inference(COMPILED_MODEL_PATH, MODEL_PATH) + model.load(COMPILED_MODEL_PATH) + print("✓ Model loaded") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right", trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Run tests + print("\n" + "="*80) + print("Running Tests") + print("="*80) + + print("\n1. Smoke Test (Model Loading)...") + test_model_loads(model) + + print("\n2. 
Generation Test...") + test_model_generates(model, tokenizer) + + print("\n3. Coherence Test...") + test_output_coherence(model, tokenizer) + + print("\n4. TTFT Performance Test...") + test_performance_ttft(model, tokenizer) + + print("\n5. Throughput Performance Test...") + test_performance_throughput(model, tokenizer) print("\n" + "="*80) - print("✓ Template structure verified!") + print("✓ All tests passed!") print("="*80) From 8e4b2dd34bf0e1c155c512916935e139eab0cb47 Mon Sep 17 00:00:00 2001 From: Deeptanshu Singh Date: Thu, 26 Feb 2026 13:28:56 -0500 Subject: [PATCH 2/2] Removing internal names --- contrib/models/granite-3.1-8b-instruct/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contrib/models/granite-3.1-8b-instruct/README.md b/contrib/models/granite-3.1-8b-instruct/README.md index 287f757..747631f 100644 --- a/contrib/models/granite-3.1-8b-instruct/README.md +++ b/contrib/models/granite-3.1-8b-instruct/README.md @@ -129,6 +129,6 @@ python3 test/integration/test_model.py ## Maintainer -Neuroboros Team - Annapurna Labs +Annapurna Labs **Last Updated:** 2026-02-06