`contrib/models/granite-3.1-8b-instruct/README.md`

NeuronX Distributed Inference implementation of Granite 3.1 8B Instruct.

## Architecture Details

- **Layers:** 32
- **Hidden Size:** 4096
- **Attention Heads:** 32
- **Key-Value Heads:** 8 (GQA)
- **Vocabulary:** 49152
- **Max Position Embeddings:** 131072
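
These values imply a head dimension of 128 and a 4:1 grouped-query layout. A quick sanity check (standalone Python, not part of the implementation):

```python
# Derived head geometry for Granite 3.1 8B (values from the list above).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

head_dim = hidden_size // num_attention_heads               # 128
kv_group_size = num_attention_heads // num_key_value_heads  # 4 query heads share each KV head

print(head_dim, kv_group_size)  # 128 4
```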

### Granite-Specific Scaling Factors

Granite uses custom scaling factors that differ from standard Llama:

| Parameter | Value | Description |
|-----------|-------|-------------|
| `embedding_multiplier` | 12.0 | Scales input embeddings after lookup |
| `attention_multiplier` | 0.0078125 (1/head_dim) | Custom attention scaling instead of 1/√head_dim |
| `residual_multiplier` | 0.22 | Scales residual connections |
| `logits_scaling` | 16.0 | Divides output logits |
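
Notably, `attention_multiplier` is exactly `1/head_dim`, not the Llama-style `1/√head_dim`. A standalone check of the arithmetic (illustrative only, not code from this implementation):

```python
import math

# Granite's custom attention scale (from the table above).
attention_multiplier = 0.0078125
head_dim = 128  # hidden_size 4096 / 32 attention heads

# Granite scales attention scores by 1/head_dim.
assert attention_multiplier == 1 / head_dim

# For comparison, standard Llama-style scaling would be 1/sqrt(head_dim),
# roughly 11x larger.
standard_scale = 1 / math.sqrt(head_dim)
print(standard_scale)  # ≈ 0.088
```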

## Validation Results

**Validated:** 2026-02-06
**Configuration:** TP=2, batch_size=1, seq_len=128, bfloat16

### Test Results

| Test | Status | Result |
|------|--------|--------|
| Smoke Test | ✅ PASS | Model loads successfully |
| Token Matching | ✅ PASS | **100% match** (64/64 tokens) |
| TTFT (P50) | ✅ PASS | ~20ms (threshold: 100ms) |
| Throughput | ✅ PASS | ~100 tok/s (threshold: 10 tok/s) |

### Performance Metrics

| Metric | Value |
|--------|-------|
| TTFT (P50) | ~20ms |
| Throughput | ~100 tokens/s |

**Status:** ✅ VALIDATED

## Critical Implementation Notes

This implementation includes critical fixes for Granite's custom scaling:

1. **Attention Multiplier Fix**: The `prep_qkv_tensors` method in `NeuronGraniteAttention` applies a correction factor to Q tensors to convert from the standard `1/√head_dim` scaling to Granite's `attention_multiplier`.

2. **Embedding Multiplier**: Applied in `get_model_output` after embedding lookup (not to weights, to handle tied embeddings correctly).

3. **Logits Scaling**: Applied via `ScaledColumnParallelLinear` which divides output by `logits_scaling`.

4. **Residual Multiplier**: Applied in `NeuronGraniteDecoderLayer` to scale residual connections.
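
The correction in note 1 can be illustrated numerically: a standard attention kernel scales scores by `1/√head_dim`, so pre-multiplying Q by `attention_multiplier * √head_dim` makes the effective scale equal `attention_multiplier`. This is a standalone sketch of the arithmetic, not the actual `prep_qkv_tensors` code:

```python
import math

head_dim = 128
attention_multiplier = 0.0078125  # Granite's target scale, 1/head_dim

# Correction factor applied to Q before a kernel that divides by sqrt(head_dim).
correction = attention_multiplier * math.sqrt(head_dim)

# Effective score scale after the kernel's standard 1/sqrt(head_dim) division.
effective_scale = correction / math.sqrt(head_dim)
assert math.isclose(effective_scale, attention_multiplier)
```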

## Usage

```python
import torch
from transformers import AutoTokenizer
from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config

# Import model classes from src
from src.modeling_granite import NeuronGraniteForCausalLM, GraniteInferenceConfig

model_path = "/path/to/granite-3.1-8b-instruct/"
compiled_model_path = "/path/to/compiled/"
neuron_config = NeuronConfig(
    # ... (additional arguments elided in the diff view)
    torch_dtype=torch.bfloat16,
)

config = GraniteInferenceConfig(
    neuron_config,
    load_config=load_pretrained_config(model_path),
)

# Compile and load
model = NeuronGraniteForCausalLM(model_path, config)
model.compile(compiled_model_path)
model.load(compiled_model_path)

# Generate
tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer("Hello, how are you?", return_tensors="pt")
# ... (see integration test for full example)
```

python3 test/integration/test_model.py

## Maintainer

Annapurna Labs

**Last Updated:** 2026-02-06