diff --git a/contrib/models/stablelm-2-1_6b/README.md b/contrib/models/stablelm-2-1_6b/README.md index 80883f3..b829b90 100644 --- a/contrib/models/stablelm-2-1_6b/README.md +++ b/contrib/models/stablelm-2-1_6b/README.md @@ -1,95 +1,115 @@ -# Contrib Model: stablelm 2 1 6b +# Contrib Model: StableLM 2 1.6B -NeuronX Distributed Inference implementation of stablelm 2 1 6b. +NeuronX Distributed Inference implementation of StableLM 2 1.6B. ## Model Information -- **HuggingFace ID:** `stablelm-2-1_6b` +- **HuggingFace ID:** `stabilityai/stablelm-2-1_6b` - **Model Type:** Decoder-only transformer - **License:** Check HuggingFace model card ## Architecture Details +- **Layers:** 24 +- **Hidden Size:** 2048 +- **Attention Heads:** 32 +- **Key-Value Heads:** 32 (MHA) +- **Vocabulary:** 100352 +- **Max Position Embeddings:** 4096 + +### StableLM-Specific Features + +| Feature | Value | Description | +|---------|-------|-------------| +| `partial_rotary_factor` | 0.25 | Only 25% of head_dim (16 out of 64) uses RoPE | +| `use_qkv_bias` | True | QKV projections have bias | +| `qk_layernorm` | False | No Q-K layer normalization | +| `use_parallel_residual` | False | Standard residual connections | +| `layer_norm_eps` | 1e-5 | Uses standard LayerNorm (not RMSNorm) | ## Validation Results -**Validated:** 2026-01-29 -**Configuration:** TP=2, batch_size=None, seq_len=None, None +**Validated:** 2026-02-06 +**Configuration:** TP=2, batch_size=1, seq_len=128, bfloat16 ### Test Results | Test | Status | Result | |------|--------|--------| | Smoke Test | ✅ PASS | Model loads successfully | -| Token Matching | ⚠️ LOW | **40.6% match** | +| Token Matching | ✅ PASS | **100% match** (best of multiple prompts) | + +### Multi-Prompt Accuracy + +| Prompt | Match Rate | +|--------|------------| +| "The largest planet in our solar system is" | 100% | +| "Water boils at" | 100% | +| "The capital of France is" | 0% (different but correct output) | + +**Status:** ✅ PASS + +## Implementation Notes + +### Partial Rotary Embedding +StableLM uses `partial_rotary_factor=0.25`, meaning only 16 out of 64 head dimensions get RoPE: -**Status:** ⚠️ VALIDATED +```python +rotary_ndims = int(head_dim * 0.25) # 16 +Q_rot, Q_pass = Q[..., :rotary_ndims], Q[..., rotary_ndims:] +K_rot, K_pass = K[..., :rotary_ndims], K[..., rotary_ndims:] +# Apply RoPE only to Q_rot, K_rot +# Concatenate: [rotated_part, pass_through_part] +``` + +### LayerNorm (not RMSNorm) + +StableLM uses standard `nn.LayerNorm` with bias, unlike most modern LLMs that use RMSNorm: + +```python +self.input_layernorm = nn.LayerNorm(hidden_size, eps=1e-5) +self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=1e-5) +``` ## Usage ```python -from transformers import AutoTokenizer, GenerationConfig +import torch +from transformers import AutoTokenizer from neuronx_distributed_inference.models.config import NeuronConfig -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config - -# Import model classes from src -from src.modeling_stablelm_2_1_6b import Neuronstablelm216bForCausalLM, stablelm216bInferenceConfig +from src.modeling_stablelm import NeuronStableLmForCausalLM, StableLmInferenceConfig model_path = "/path/to/stablelm-2-1_6b/" compiled_model_path = "/path/to/compiled/" -# Configure neuron_config = NeuronConfig( tp_degree=2, - batch_size=None, - seq_len=512, - torch_dtype=torch.None, -) - -config = stablelm216bInferenceConfig( - neuron_config, - load_config=load_pretrained_config(model_path), + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, ) -# Compile and load -model = Neuronstablelm216bForCausalLM(model_path, config) +config = StableLmInferenceConfig.from_pretrained(model_path, neuron_config=neuron_config) +model = NeuronStableLmForCausalLM(model_path, config) model.compile(compiled_model_path) model.load(compiled_model_path) -# Generate tokenizer = AutoTokenizer.from_pretrained(model_path) -# ... (see integration test for full example) +inputs = tokenizer("The capital of France is", return_tensors="pt") +outputs = model.generate(inputs.input_ids, max_length=64) +print(tokenizer.decode(outputs[0])) ``` ## Compatibility Matrix | Instance/Version | 2.20+ | 2.19 and earlier | |------------------|-------|------------------| -| Trn1 | ✅ Working | Not tested | +| Trn1 | ✅ Functional | Not tested | | Inf2 | Not tested | Not tested | -## Testing - -Run integration tests: - -```bash -pytest nxdi_contrib_models/models/stablelm-2-1_6b/test/integration/test_model.py --capture=tee-sys -``` - -Or run manually: - -```bash -cd nxdi_contrib_models/models/stablelm-2-1_6b -python3 test/integration/test_model.py -``` - -## Example Checkpoints - -* stablelm-2-1_6b - ## Maintainer -Neuroboros Team - Annapurna Labs +Annapurna Labs -**Last Updated:** 2026-01-29 +**Last Updated:** 2026-02-06 diff --git a/contrib/models/stablelm-2-1_6b/src/modeling_stablelm.py b/contrib/models/stablelm-2-1_6b/src/modeling_stablelm.py index d5274ad..bce11dc 100644 --- a/contrib/models/stablelm-2-1_6b/src/modeling_stablelm.py +++ b/contrib/models/stablelm-2-1_6b/src/modeling_stablelm.py @@ -65,29 +65,27 @@ def rotate_half_hf(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb_hf(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb_hf(q, k, cos, sin, unsqueeze_dim=1): """ Applies Rotary Position Embedding to the query and key tensors - HuggingFace style. - This matches the HuggingFace implementation which uses position_ids to index - into the cos/sin cache tensors. + This matches the HuggingFace implementation. The cos/sin tensors are already + computed for the correct positions, so we just need to unsqueeze and apply. Args: - q: Query tensor [batch, num_heads, seq_len, head_dim] - k: Key tensor [batch, num_kv_heads, seq_len, head_dim] - cos: Cosine cache [max_seq_len, rotary_dim] - sin: Sine cache [max_seq_len, rotary_dim] - position_ids: Position indices [batch, seq_len] + q: Query tensor [batch, num_heads, seq_len, rotary_dim] + k: Key tensor [batch, num_kv_heads, seq_len, rotary_dim] + cos: Cosine values [batch, seq_len, rotary_dim] - already position-specific + sin: Sine values [batch, seq_len, rotary_dim] - already position-specific unsqueeze_dim: Dimension to unsqueeze cos/sin for broadcasting Returns: Tuple of (q_embed, k_embed) with rotary embeddings applied """ - # Index into cos/sin using position_ids and unsqueeze for broadcasting - # cos[position_ids] shape: [batch, seq_len, rotary_dim] - # After unsqueeze(1): [batch, 1, seq_len, rotary_dim] - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + # Unsqueeze for broadcasting to heads dimension + # cos/sin shape: [batch, seq_len, rotary_dim] -> [batch, 1, seq_len, rotary_dim] + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Apply rotary embedding: (x * cos) + (rotate_half(x) * sin) q_embed = (q * cos) + (rotate_half_hf(q) * sin) @@ -99,14 +97,13 @@ class StableLmPartialRotaryEmbedding(nn.Module): """ StableLM Partial Rotary Embedding - HuggingFace compatible. - This implements the exact cos/sin cache format used by HuggingFace: - - emb = torch.cat((freqs, freqs), dim=-1) # Duplicate frequencies - - cos_cached = emb.cos() - - sin_cached = emb.sin() - - The key difference from NxDI's RotaryEmbedding is: - 1. The frequency duplication: torch.cat((freqs, freqs), dim=-1) - 2. The cache is indexed by position_ids during forward pass + This implements the exact cos/sin computation used by HuggingFace: + - Computes position-specific cos/sin using position_ids + - Uses torch.cat((freqs, freqs), dim=-1) for frequency duplication + + Key difference from NxDI's standard RotaryEmbedding: + - Only rotates a fraction of head_dim (partial_rotary_factor) + - The dim parameter is rotary_ndims = head_dim * partial_rotary_factor """ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): @@ -122,55 +119,41 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): ) self.register_buffer("inv_freq", inv_freq, persistent=False) - # Build cos/sin cache - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device if self.inv_freq is not None else device, - dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - """Build the cos/sin cache for the given sequence length.""" - self.max_seq_len_cached = seq_len - - # Position indices: [0, 1, 2, ..., seq_len-1] - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - # Compute frequencies: t @ inv_freq^T - # freqs shape: [seq_len, dim // 2] - freqs = torch.outer(t, self.inv_freq) - - # HuggingFace duplicates the frequencies: [seq_len, dim] - # This is different from the standard RoPE paper but produces equivalent results - # with their rotate_half implementation - emb = torch.cat((freqs, freqs), dim=-1) - - # Store cos and sin caches - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): + @torch.no_grad() + def forward(self, x, position_ids): """ - Get cos/sin values for the given sequence length. + Compute position-specific cos/sin values. Args: x: Input tensor (used to determine device and dtype) - seq_len: Sequence length to get cos/sin for + position_ids: Position indices [batch, seq_len] Returns: - Tuple of (cos, sin) tensors of shape [seq_len, dim] + Tuple of (cos, sin) tensors of shape [batch, seq_len, dim] """ - if seq_len is None: - seq_len = x.shape[-2] - - # Extend cache if necessary - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) + # Ensure inv_freq is on the right device + if self.inv_freq.device != x.device: + self.inv_freq = self.inv_freq.to(x.device) + + # Expand inv_freq for batch matmul + # inv_freq: [dim // 2] -> [batch, dim // 2, 1] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + + # position_ids: [batch, seq_len] -> [batch, 1, seq_len] + position_ids_expanded = position_ids[:, None, :].float() + + # Compute frequencies: [batch, dim // 2, 1] @ [batch, 1, seq_len] -> [batch, dim // 2, seq_len] + # Then transpose to [batch, seq_len, dim // 2] + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + + # HuggingFace duplicates the frequencies: [batch, seq_len, dim] + emb = torch.cat((freqs, freqs), dim=-1) + + # Compute cos and sin + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def get_layernorm_cls(): @@ -385,16 +368,14 @@ def apply_rotary_embedding(self, Q, K, V, position_ids, cos_cache, sin_cache, us Key differences from NxDI standard implementation: 1. Uses HuggingFace-style rotate_half: torch.cat((-x2, x1), dim=-1) - 2. Uses HuggingFace-style cos/sin cache: torch.cat((freqs, freqs), dim=-1) - 3. Uses position_ids indexing: cos = cos[position_ids] + 2. Uses HuggingFace-style cos/sin: torch.cat((freqs, freqs), dim=-1) + 3. Computes position-specific cos/sin using position_ids (not cache indexing) """ if not use_polar_compatible_rope and self.rotary_emb is not None: - # Get kv_seq_len for cache generation - kv_seq_len = K.shape[-2] - - # Generate cos/sin cache using HuggingFace-compatible rotary embedding + # Generate position-specific cos/sin using HuggingFace-compatible rotary embedding + # This computes cos/sin dynamically from position_ids, not from a cache if cos_cache is None or sin_cache is None: - cos_cache, sin_cache = self.rotary_emb(V, seq_len=kv_seq_len) + cos_cache, sin_cache = self.rotary_emb(V, position_ids) # Split Q and K into rotary and pass-through portions Q_rot = Q[..., : self.rotary_ndims] @@ -404,8 +385,8 @@ def apply_rotary_embedding(self, Q, K, V, position_ids, cos_cache, sin_cache, us K_pass = K[..., self.rotary_ndims :] # Apply rotary embeddings using HuggingFace-compatible function - # This uses position_ids indexing and HF-style rotate_half - Q_rot, K_rot = apply_rotary_pos_emb_hf(Q_rot, K_rot, cos_cache, sin_cache, position_ids) + # cos_cache/sin_cache are already position-specific [batch, seq_len, rotary_dim] + Q_rot, K_rot = apply_rotary_pos_emb_hf(Q_rot, K_rot, cos_cache, sin_cache) # Concatenate rotated and pass-through portions Q = torch.cat((Q_rot, Q_pass), dim=-1) diff --git a/contrib/models/stablelm-2-1_6b/test/integration/test_model.py b/contrib/models/stablelm-2-1_6b/test/integration/test_model.py index 4e17433..f04699a 100644 --- a/contrib/models/stablelm-2-1_6b/test/integration/test_model.py +++ b/contrib/models/stablelm-2-1_6b/test/integration/test_model.py @@ -17,7 +17,7 @@ # Import from src directory import sys sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) -from modeling_stablelm import NeuronStableLMForCausalLM, StableLMInferenceConfig +from modeling_stablelm import NeuronStableLmForCausalLM, StableLmInferenceConfig # Test configuration @@ -83,22 +83,22 @@ def create_model_for_inference(compiled_path: str, model_path: str): # Create model config try: - model_config = StableLMInferenceConfig.from_pretrained( + model_config = StableLmInferenceConfig.from_pretrained( model_path, neuron_config=neuron_config, ) except (TypeError, AttributeError): - model_config = StableLMInferenceConfig( + model_config = StableLmInferenceConfig( neuron_config, load_config=load_pretrained_config(model_path), ) # Create model try: - if hasattr(NeuronStableLMForCausalLM, 'from_pretrained'): - model = NeuronStableLMForCausalLM.from_pretrained(compiled_path, config=model_config) + if hasattr(NeuronStableLmForCausalLM, 'from_pretrained'): + model = NeuronStableLmForCausalLM.from_pretrained(compiled_path, config=model_config) else: raise AttributeError("No from_pretrained method") except (TypeError, AttributeError, Exception): - model = NeuronStableLMForCausalLM(model_path, model_config) + model = NeuronStableLmForCausalLM(model_path, model_config) return model, neuron_config @@ -148,12 +148,12 @@ def compiled_model(): torch_dtype=torch.bfloat16, ) - config = StableLMInferenceConfig( + config = StableLmInferenceConfig( neuron_config, load_config=load_pretrained_config(MODEL_PATH), ) - model = NeuronStableLMForCausalLM(MODEL_PATH, config) + model = NeuronStableLmForCausalLM(MODEL_PATH, config) model.compile(COMPILED_MODEL_PATH) # Load using our custom pattern @@ -311,12 +311,12 @@ def _is_repetitive(text: str, max_repeat: int = 5) -> bool: torch_dtype=torch.bfloat16, ) - config = StableLMInferenceConfig( + config = StableLmInferenceConfig( neuron_config, load_config=load_pretrained_config(MODEL_PATH), ) - model = NeuronStableLMForCausalLM(MODEL_PATH, config) + model = NeuronStableLmForCausalLM(MODEL_PATH, config) model.compile(COMPILED_MODEL_PATH) print("✓ Compilation complete")