114 changes: 67 additions & 47 deletions contrib/models/stablelm-2-1_6b/README.md
@@ -1,95 +1,115 @@
# Contrib Model: StableLM 2 1.6B

NeuronX Distributed Inference implementation of StableLM 2 1.6B.

## Model Information

- **HuggingFace ID:** `stabilityai/stablelm-2-1_6b`
- **Model Type:** Decoder-only transformer
- **License:** Check HuggingFace model card

## Architecture Details

- **Layers:** 24
- **Hidden Size:** 2048
- **Attention Heads:** 32
- **Key-Value Heads:** 32 (MHA)
- **Vocabulary:** 100352
- **Max Position Embeddings:** 4096
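The per-head and partial-rotary dimensions used throughout this README follow directly from the numbers above; a quick arithmetic sanity check (plain Python, no model code):

```python
# Derived dimensions for StableLM 2 1.6B (values taken from the table above).
hidden_size = 2048
num_attention_heads = 32
partial_rotary_factor = 0.25

head_dim = hidden_size // num_attention_heads          # 2048 / 32 = 64
rotary_ndims = int(head_dim * partial_rotary_factor)   # 64 * 0.25 = 16

print(head_dim, rotary_ndims)  # 64 16
```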

### StableLM-Specific Features

| Feature | Value | Description |
|---------|-------|-------------|
| `partial_rotary_factor` | 0.25 | Only 25% of head_dim (16 out of 64) uses RoPE |
| `use_qkv_bias` | True | QKV projections have bias |
| `qk_layernorm` | False | No Q-K layer normalization |
| `use_parallel_residual` | False | Standard residual connections |
| `layer_norm_eps` | 1e-5 | Uses standard LayerNorm (not RMSNorm) |

## Validation Results

**Validated:** 2026-02-06
**Configuration:** TP=2, batch_size=1, seq_len=128, bfloat16

### Test Results

| Test | Status | Result |
|------|--------|--------|
| Smoke Test | ✅ PASS | Model loads successfully |
| Token Matching | ✅ PASS | **100% match** (best of multiple prompts) |

### Multi-Prompt Accuracy

| Prompt | Match Rate |
|--------|------------|
| "The largest planet in our solar system is" | 100% |
| "Water boils at" | 100% |
| "The capital of France is" | 0% (different but correct output) |

**Status:** ✅ PASS
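The match rates above are positionwise token agreement between the Neuron output and a CPU/HuggingFace reference. A minimal sketch of that metric (the helper name and the token IDs are illustrative placeholders, not the harness's actual code):

```python
# Hypothetical helper: positionwise token match rate between two generated sequences.
def token_match_rate(neuron_tokens, reference_tokens):
    n = min(len(neuron_tokens), len(reference_tokens))
    if n == 0:
        return 0.0
    # Count positions where both sequences produced the same token ID.
    matches = sum(a == b for a, b in zip(neuron_tokens[:n], reference_tokens[:n]))
    return matches / n

# Toy example with made-up token IDs (not real model output).
print(token_match_rate([1, 2, 3, 4], [1, 2, 9, 4]))  # 0.75
```

As the "capital of France" row shows, a low match rate does not always mean a wrong answer: two samplers can diverge to different but equally valid continuations.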

## Implementation Notes

### Partial Rotary Embedding

StableLM uses `partial_rotary_factor=0.25`, meaning only 16 out of 64 head dimensions get RoPE:

```python
rotary_ndims = int(head_dim * 0.25) # 16
Q_rot, Q_pass = Q[..., :rotary_ndims], Q[..., rotary_ndims:]
K_rot, K_pass = K[..., :rotary_ndims], K[..., rotary_ndims:]
# Apply RoPE only to Q_rot, K_rot
# Concatenate: [rotated_part, pass_through_part]
```
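The split-rotate-concatenate flow above can be exercised on standalone tensors (this is a self-contained sketch, not the model's actual modules); note that the pass-through slice comes back bit-identical:

```python
import torch

def rotate_half(x):
    # HuggingFace-style rotate_half: swap halves and negate the second.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

head_dim, rotary_ndims, seq_len = 64, 16, 8
q = torch.randn(1, 32, seq_len, head_dim)

# cos/sin for the rotary slice only: shape [seq_len, rotary_ndims].
pos = torch.arange(seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims))
freqs = torch.outer(pos, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # HF duplicates the frequencies
cos, sin = emb.cos(), emb.sin()

# Split, rotate only the first 16 dims, then concatenate back.
q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
q_out = torch.cat((q_rot, q_pass), dim=-1)

# The last 48 dims bypass RoPE entirely.
assert torch.equal(q_out[..., rotary_ndims:], q_pass)
```

Because RoPE is a pure rotation, the norm of the rotary slice is also preserved, which is a convenient extra invariant to check in tests.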

### LayerNorm (not RMSNorm)

StableLM uses standard `nn.LayerNorm` with bias, unlike most modern LLMs that use RMSNorm:

```python
self.input_layernorm = nn.LayerNorm(hidden_size, eps=1e-5)
self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=1e-5)
```
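The practical difference is that LayerNorm subtracts the per-token mean (and carries learnable bias terms), while RMSNorm only rescales by the root-mean-square. A minimal comparison sketch (the bare-bones `rms_norm` here is for illustration, with the learnable weight omitted):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 4, 2048) + 5.0  # deliberately non-zero-mean input

layer_norm = nn.LayerNorm(2048, eps=1e-5)

def rms_norm(x, eps=1e-5):
    # RMSNorm: no mean subtraction, no bias (learnable weight omitted for clarity).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

ln_out = layer_norm(x)
rms_out = rms_norm(x)

# LayerNorm output is (approximately) zero-mean; RMSNorm output keeps the offset.
print(bool(ln_out.mean(-1).abs().max() < 1e-3), bool(rms_out.mean(-1).abs().max() > 0.1))  # True True
```

This is why LayerNorm weights from a StableLM checkpoint cannot simply be loaded into an RMSNorm module: the bias and mean-centering would be silently dropped.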

## Usage

```python
import torch
from transformers import AutoTokenizer
from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config

# Import model classes from src
from src.modeling_stablelm import NeuronStableLmForCausalLM, StableLmInferenceConfig

model_path = "/path/to/stablelm-2-1_6b/"
compiled_model_path = "/path/to/compiled/"

# Configure
neuron_config = NeuronConfig(
    tp_degree=2,
    batch_size=1,
    seq_len=128,
    torch_dtype=torch.bfloat16,
)

config = StableLmInferenceConfig.from_pretrained(model_path, neuron_config=neuron_config)

# Compile and load
model = NeuronStableLmForCausalLM(model_path, config)
model.compile(compiled_model_path)
model.load(compiled_model_path)

# Generate
tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(inputs.input_ids, max_length=64)
print(tokenizer.decode(outputs[0]))
```

## Compatibility Matrix

| Instance/Version | 2.20+ | 2.19 and earlier |
|------------------|-------|------------------|
| Trn1 | ✅ Functional | Not tested |
| Inf2 | Not tested | Not tested |

## Testing

Run integration tests:

```bash
pytest nxdi_contrib_models/models/stablelm-2-1_6b/test/integration/test_model.py --capture=tee-sys
```

Or run manually:

```bash
cd nxdi_contrib_models/models/stablelm-2-1_6b
python3 test/integration/test_model.py
```

## Example Checkpoints

* stablelm-2-1_6b

## Maintainer

Annapurna Labs

**Last Updated:** 2026-02-06
125 changes: 53 additions & 72 deletions contrib/models/stablelm-2-1_6b/src/modeling_stablelm.py
@@ -65,29 +65,27 @@ def rotate_half_hf(x):
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_hf(q, k, cos, sin, unsqueeze_dim=1):
    """
    Applies Rotary Position Embedding to the query and key tensors - HuggingFace style.

    This matches the HuggingFace implementation. The cos/sin tensors are already
    computed for the correct positions, so we just need to unsqueeze and apply.

    Args:
        q: Query tensor [batch, num_heads, seq_len, rotary_dim]
        k: Key tensor [batch, num_kv_heads, seq_len, rotary_dim]
        cos: Cosine values [batch, seq_len, rotary_dim] - already position-specific
        sin: Sine values [batch, seq_len, rotary_dim] - already position-specific
        unsqueeze_dim: Dimension to unsqueeze cos/sin for broadcasting

    Returns:
        Tuple of (q_embed, k_embed) with rotary embeddings applied
    """
    # Unsqueeze for broadcasting to the heads dimension
    # cos/sin shape: [batch, seq_len, rotary_dim] -> [batch, 1, seq_len, rotary_dim]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Apply rotary embedding: (x * cos) + (rotate_half(x) * sin)
    q_embed = (q * cos) + (rotate_half_hf(q) * sin)
@@ -99,14 +97,13 @@ class StableLmPartialRotaryEmbedding(nn.Module):
    """
    StableLM Partial Rotary Embedding - HuggingFace compatible.

    This implements the exact cos/sin computation used by HuggingFace:
    - Computes position-specific cos/sin using position_ids
    - Uses torch.cat((freqs, freqs), dim=-1) for frequency duplication

    Key difference from NxDI's standard RotaryEmbedding:
    - Only rotates a fraction of head_dim (partial_rotary_factor)
    - The dim parameter is rotary_ndims = head_dim * partial_rotary_factor
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -122,55 +119,41 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Compute position-specific cos/sin values.

        Args:
            x: Input tensor (used to determine device and dtype)
            position_ids: Position indices [batch, seq_len]

        Returns:
            Tuple of (cos, sin) tensors of shape [batch, seq_len, dim]
        """
        # Ensure inv_freq is on the right device
        if self.inv_freq.device != x.device:
            self.inv_freq = self.inv_freq.to(x.device)

        # Expand inv_freq for batch matmul
        # inv_freq: [dim // 2] -> [batch, dim // 2, 1]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)

        # position_ids: [batch, seq_len] -> [batch, 1, seq_len]
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute frequencies: [batch, dim // 2, 1] @ [batch, 1, seq_len] -> [batch, dim // 2, seq_len]
        # Then transpose to [batch, seq_len, dim // 2]
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)

        # HuggingFace duplicates the frequencies: [batch, seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)

        # Compute cos and sin
        cos = emb.cos()
        sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def get_layernorm_cls():
@@ -385,16 +368,14 @@ def apply_rotary_embedding(self, Q, K, V, position_ids, cos_cache, sin_cache, us

        Key differences from NxDI standard implementation:
        1. Uses HuggingFace-style rotate_half: torch.cat((-x2, x1), dim=-1)
        2. Uses HuggingFace-style cos/sin: torch.cat((freqs, freqs), dim=-1)
        3. Computes position-specific cos/sin using position_ids (not cache indexing)
        """
        if not use_polar_compatible_rope and self.rotary_emb is not None:
            # Generate position-specific cos/sin using HuggingFace-compatible rotary embedding
            # This computes cos/sin dynamically from position_ids, not from a cache
            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = self.rotary_emb(V, position_ids)

            # Split Q and K into rotary and pass-through portions
            Q_rot = Q[..., : self.rotary_ndims]
Expand All @@ -404,8 +385,8 @@ def apply_rotary_embedding(self, Q, K, V, position_ids, cos_cache, sin_cache, us
K_pass = K[..., self.rotary_ndims :]

# Apply rotary embeddings using HuggingFace-compatible function
# This uses position_ids indexing and HF-style rotate_half
Q_rot, K_rot = apply_rotary_pos_emb_hf(Q_rot, K_rot, cos_cache, sin_cache, position_ids)
# cos_cache/sin_cache are already position-specific [batch, seq_len, rotary_dim]
Q_rot, K_rot = apply_rotary_pos_emb_hf(Q_rot, K_rot, cos_cache, sin_cache)

# Concatenate rotated and pass-through portions
Q = torch.cat((Q_rot, Q_pass), dim=-1)