162 changes: 162 additions & 0 deletions contributed/batch_invariance/README.md
@@ -0,0 +1,162 @@
# NKI Batch Invariance Study

A study of batch invariance in Neuron Kernel Interface (NKI), replicating and extending [Thinking Machines' "Defeating Nondeterminism in LLM Inference"](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) research.

## What is Batch Invariance?

Following [Thinking Machines' definition](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/):

**Batch invariance** requires:
1. **Run-to-run determinism**: Same prompt + same model + same inputs + same seed + same runtime config → bitwise-identical outputs across runs
2. **Batching independence**: Changing inference batching behavior (batch size, request packing, continuous batching order) → no output change

A batch-invariant system guarantees that the *way* you batch requests doesn't affect the numerical output—critical for reproducible LLM inference.
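
The bar here is strict: outputs must match at the byte level, not merely within a tolerance. A minimal numpy sketch of such a check (the helper name is ours, not from this repo's test suite):

```python
import numpy as np

def bitwise_equal(a, b):
    # Batch invariance demands byte-identical outputs, not merely allclose.
    return a.shape == b.shape and a.tobytes() == b.tobytes()

x = np.linspace(0, 1, 8, dtype=np.float32)
assert bitwise_equal(x.copy(), x.copy())           # identical buffers pass
assert not bitwise_equal(x, x + np.float32(1e-7))  # tiny numeric drift fails
```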

## Overview

This project demonstrates how different tile size configurations in NKI kernels can produce varying numerical results due to floating-point non-associativity. We test whether `nki.isa` operations maintain batch invariance when reduction tile sizes change—simulating what happens when a framework dynamically selects tile sizes based on input shape.

### Baselines Used

| Baseline Type | Purpose | Method |
|---------------|---------|--------|
| **CPU Reference** | Numerical parity | NKI kernel output vs PyTorch CPU (`torch.matmul`, manual RMSNorm) |
| **NKI Self-Baseline** | Run-to-run determinism | Same kernel, 1000 iterations, verify bitwise-identical outputs |
| **Tile Configuration Comparison** | Batching independence | Same kernel with different tile sizes (simulating shape-dependent selection) |

## Key Findings

### 1. Run-to-Run Determinism Confirmed

NKI ISA kernels produce bitwise-identical results across 1000 iterations with the same configuration.

### 2. Tile Size Invariance with `nki.isa`

**Critical finding**: `nki.isa` operations produce identical results regardless of tile size configuration in bfloat16 precision.

| Operation (tile configs compared) | bfloat16 | float32 |
|-----------------------------------|----------|---------|
| **MatMul** (K_TILE=128 vs 64) | ✅ Invariant (diff=0.0) | ✗ Variant (diff=6.1e-05) |
| **RMSNorm** (HIDDEN_TILE=128 vs 64) | ✅ Invariant (diff=0.0) | ✗ Variant (diff=2.4e-07) |

The bfloat16 invariance is the key result—reduced precision formats are where batch variance is most visible and problematic in practice, and ISA operations eliminate it entirely.

### 3. Historical Note: `nki.lang` Showed Variance

Prior to the NKI beta release, `nki.lang` operations exhibited tile-size-dependent variance:

| Operation | Kernel Type | float32 | bfloat16 | Amplification |
|-----------|-------------|---------|----------|---------------|
| **MatMul** | `nki.lang` | ✗ Variance (4.6e-05) | ✗ Variance (0.0078) | 170x |
| **RMSNorm** | `nki.lang` | ✗ Variance (3.6e-07) | ✗ Variance (0.0078) | 21,845x |

The bfloat16 amplification effect (errors 170-21,845x larger than float32) made variance highly visible in reduced precision formats. This behavior motivated the shift to `nki.isa` operations.
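
The amplification follows from the coarser spacing between representable values at reduced precision. A quick numpy illustration — using float16 as a stand-in, since numpy has no native bfloat16 (whose even shorter mantissa makes it coarser still):

```python
import numpy as np

true_val = 1000.1  # not exactly representable in any binary float

err32 = abs(float(np.float32(true_val)) - true_val)
err16 = abs(float(np.float16(true_val)) - true_val)

# Spacing of representable values near 1000:
#   float32 ~ 6.1e-5, float16 ~ 0.5
# so the same value rounds thousands of times more coarsely in float16.
print(f"float32 err={err32:.2e}, float16 err={err16:.2e}, ratio={err16 / err32:.0f}")
```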

## How Tile Size Selection Can Break Batch Invariance

**The problem**: When reduction dimension tile sizes are selected based on input shape, the accumulation order changes. Due to floating-point non-associativity, different accumulation orders can produce different results:


```
(a + b) + c ≠ a + (b + c)    (in finite-precision arithmetic)
```

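Even in double precision this bites with ordinary values:

```python
a, b, c = 1.0, 1e20, -1e20

left = (a + b) + c   # 1.0 is absorbed when added to 1e20, then cancelled
right = a + (b + c)  # b and c cancel exactly before 1.0 is added

print(left, right)   # 0.0 vs 1.0
```
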
**Triton Split-K (Shape-Dependent)**:
```python
num_pid_k = tl.cdiv(k, block_k * split_k)  # Tile count varies with K dimension
```

**This Study's Simulation**:
Our kernels use a `deterministic` flag to compare two fixed tile configurations, simulating what happens when a framework chooses tile sizes based on input shape:

```python
# MatMul kernel
if deterministic:
    K_TILE = 128                      # Fixed strategy
else:
    K_TILE = 64 if K <= 512 else 512  # Shape-dependent strategy

# RMSNorm kernel
HIDDEN_TILE = 128 if deterministic else 64  # Different accumulation granularity
```

**Why this matters**: If an inference framework selects tile sizes based on batch dimensions, then changing batch size changes accumulation order—potentially breaking batch invariance even though each individual run is deterministic.
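
A toy numpy sketch of this failure mode — the kernel and its chunk-selection rule below are hypothetical illustrations, not this repo's NKI kernels:

```python
import numpy as np

def toy_row_sum(row, batch_size):
    # Hypothetical kernel: the reduction chunk size depends on batch size.
    chunk = 2 if batch_size == 1 else 4
    partials = []
    for i in range(0, len(row), chunk):
        acc = 0.0
        for v in row[i:i + chunk]:
            acc += v              # accumulate within a chunk...
        partials.append(acc)
    return sum(partials)          # ...then combine partial sums

row = np.array([1.0, 1e20, -1e20, 1.0])
alone   = toy_row_sum(row, batch_size=1)  # chunks: (1.0, 1e20), (-1e20, 1.0)
batched = toy_row_sum(row, batch_size=8)  # one chunk of four
print(alone, batched)  # 0.0 vs 1.0 -- same row, different answer
```

Each run is perfectly deterministic; only the batch-dependent chunking breaks invariance.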

## Test Methodology

### What Each Test Validates

| Test | Validates | Method |
|------|-----------|--------|
| `test_determinism()` | Run-to-run determinism | Same config → identical results across 1000 runs |
| `test_tiling_invariance()` | Tile size independence | K_TILE=128 vs K_TILE=64 → same results? |
| `test_matmul_parity()` | Numerical correctness | NKI output matches `torch.matmul` |
| `test_rmsnorm_parity()` | Numerical correctness | NKI output matches PyTorch RMSNorm reference |
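
The run-to-run check reduces to a small harness; a numpy sketch of the pattern (names ours — the real `test_determinism()` drives the NKI kernels on device):

```python
import numpy as np

def check_run_to_run(kernel, args, iters=1000):
    # Re-run the same kernel on the same inputs; demand bitwise equality.
    ref = kernel(*args)
    return all(kernel(*args).tobytes() == ref.tobytes() for _ in range(iters))

rng = np.random.default_rng(0)
a = rng.standard_normal((256, 512)).astype(np.float32)
assert check_run_to_run(lambda x: x.sum(axis=1), (a,), iters=100)
```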

### Tile Size Variance Demonstration

```python
# Compare deterministic=True (K_TILE=128) vs deterministic=False (K_TILE=64)
out_k128 = nki_matmul_kernel_isa(a, b, deterministic=True)
out_k64 = nki_matmul_kernel_isa(a, b, deterministic=False)

diff = (out_k128 - out_k64).abs().max().item()
# With nki.isa: diff == 0.0 (batch invariant)
```

## Running the Tests

```bash
cd contributed/batch_invariance
python test_batch_invariance.py
```

### Expected Output

1. **Determinism test**: 1000 iterations produce identical results
2. **Parity tests**: NKI kernels match PyTorch reference within tolerance
3. **Tiling invariance**: Different tile sizes produce identical results (diff=0.0)

## Project Structure


batch_invariance/
├── README.md # This document
├── test_batch_invariance.py # Main test suite
└── kernels/
├── init.py
├── matmul_batch_invariant.py # MatMul ISA implementation
└── rmsnorm_batch_invariant.py # RMSNorm ISA implementation

## Implications for LLM Inference

### For Deterministic Inference
- **Use `nki.isa` operations** for batch-invariant kernels
- **bfloat16 precision** works reliably with ISA operations
- **Fixed tile sizes** avoid shape-dependent variance (though in our tests `nki.isa` operations tolerated tile-size changes)

### Why This Matters
Batch invariance ensures that:
- Changing batch size doesn't change model outputs
- Request packing order doesn't affect results
- Continuous batching produces reproducible inference
- Debugging and testing become tractable

## Future Work

1. **Batch Invariant Attention**: Implement attention mechanisms using ISA operations
2. **LLM Integration**: Full forward pass comparison with varying batch configurations
3. **Performance Analysis**: Quantify any performance trade-offs with ISA approach
4. **Extended Precision Study**: Investigate fp16, int8 behavior

## Core Insight

**Batch invariance requires that accumulation order doesn't affect the final result.**

Our tile size comparison (K_TILE=128 vs K_TILE=64) simulates shape-dependent tiling. The finding that `nki.isa` operations produce identical results regardless of tile configuration demonstrates a path to deterministic LLM inference on Neuron hardware—even when batching configurations change.

## References

- [Thinking Machines: Defeating Nondeterminism in LLM Inference](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/)
- [Thinking Machines GitHub: Batch Invariant Operations](https://github.com/thinking-machines-lab/batch_invariant_ops)
- [Meta: Triton Split-K Kernel Paper](https://scontent-dfw5-2.xx.fbcdn.net/v/t39.2365-6/418514147_782803483888724_2886980548537654804_n.pdf)
- [AWS Neuron Documentation](https://awsdocs-neuron.readthedocs-hosted.com/)
- [NKI Programming Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/)

## Author

Implementation and analysis by Josh Longenecker, based on foundational work by Thinking Machines Lab.
76 changes: 76 additions & 0 deletions contributed/batch_invariance/kernels/matmul_batch_invariant.py
@@ -0,0 +1,76 @@
"""
Batch-Invariant MatMul Kernel

This kernel demonstrates batch invariance in matrix multiplication by controlling
the K-dimension (reduction) tiling strategy.
"""

import nki
import nki.isa as nisa
import nki.language as nl

@nki.jit
def nki_matmul_kernel_isa(a, b, deterministic=True):
    """
    Matrix multiplication with a batch-invariance toggle.

    deterministic=True:  fixed K_TILE=128
    deterministic=False: K_TILE chosen from the K dimension (shape-dependent)

    This demonstrates how different K tiling changes accumulation order and
    can therefore change numerical results.
    """
    K, M = a.shape
    N = b.shape[1]
    M_TILE = 128

    # ONLY DIFFERENCE: K_TILE strategy
    if deterministic:
        K_TILE = 128                      # Always fixed
    else:
        K_TILE = 64 if K <= 512 else 512  # Adaptive (shape-dependent)

    result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm)

    for m in nl.affine_range(M // M_TILE):
        # Accumulator for this M chunk
        c_psum = nl.ndarray((M_TILE, N), dtype=nl.float32, buffer=nl.psum)
        # Reduction over K
        for k in nl.affine_range(K // K_TILE):
            # Allocate and load a: [K_TILE, M_TILE]
            a_tile = nl.ndarray((K_TILE, M_TILE), dtype=a.dtype, buffer=nl.sbuf)
            a_start = k * K_TILE
            a_end = min(K, a_start + K_TILE)

            m_start = m * M_TILE
            m_end = min(M, m_start + M_TILE)

            nisa.dma_copy(
                src=a[a_start:a_end, m_start:m_end],
                dst=a_tile,
            )

            # Allocate and load b: [K_TILE, N]
            b_start = k * K_TILE
            b_end = min(K, b_start + K_TILE)

            b_tile = nl.ndarray((K_TILE, N), dtype=b.dtype, buffer=nl.sbuf)
            nisa.dma_copy(
                src=b[b_start:b_end, 0:N],
                dst=b_tile,
            )

            # Matmul: accumulate this K tile's product into PSUM
            nisa.nc_matmul(dst=c_psum, stationary=a_tile, moving=b_tile)

        # Store this M chunk
        c_sbuf = nl.ndarray((M_TILE, N), dtype=result.dtype, buffer=nl.sbuf)
        nisa.tensor_copy(dst=c_sbuf, src=c_psum)

        c_start = m * M_TILE
        c_end = min(M, c_start + M_TILE)
        nisa.dma_copy(
            src=c_sbuf,
            dst=result[c_start:c_end, 0:N],
        )

    return result
98 changes: 98 additions & 0 deletions contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py
@@ -0,0 +1,98 @@
import math
import nki
import nki.isa as nisa
import nki.language as nl


@nki.jit
def nki_rmsnorm_kernel_isa(a, g, deterministic=True):
    out_tensor = nl.ndarray(a.shape, dtype=a.dtype, buffer=nl.shared_hbm)

    num_rows, hidden_dim = a.shape[0], a.shape[1]
    BATCH_TILE = 128
    HIDDEN_TILE = 128 if deterministic else 64

    g = g.reshape((1, hidden_dim))

    # Column of ones used to broadcast g across the batch tile via matmul
    ones_vec = nl.ndarray((1, BATCH_TILE), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=ones_vec, value=1.0)

    zero_bias = nl.ndarray((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=zero_bias, value=0.0)

    for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)):
        b_start = i * BATCH_TILE
        b_end = min(num_rows, b_start + BATCH_TILE)

        sum_sq = nl.ndarray((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.memset(dst=sum_sq, value=0.0)

        # Pass 1: Compute sum of squares
        for h in nl.affine_range(math.ceil(hidden_dim / HIDDEN_TILE)):
            h_start = h * HIDDEN_TILE
            h_end = min(hidden_dim, h_start + HIDDEN_TILE)

            x = nl.ndarray((BATCH_TILE, HIDDEN_TILE), dtype=a.dtype, buffer=nl.sbuf)
            nisa.dma_copy(
                dst=x, src=a[b_start:b_end, h_start:h_end]
            )

            x_sq = nl.ndarray(
                (BATCH_TILE, HIDDEN_TILE), dtype=nl.float32, buffer=nl.sbuf
            )
            tile_sum = nl.ndarray((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.sbuf)
            # Square each element and reduce along the hidden tile in one op
            nisa.activation_reduce(
                dst=x_sq,
                op=nl.square,
                data=x,
                reduce_op=nl.add,
                reduce_res=tile_sum,
                bias=zero_bias,
                scale=1.0,
            )

            nisa.tensor_tensor(dst=sum_sq, data1=sum_sq, data2=tile_sum, op=nl.add)

        # rms_inv = 1 / sqrt(mean(x^2))
        rms_inv = nl.ndarray((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.activation(
            dst=rms_inv,
            op=nl.rsqrt,
            data=sum_sq,
            scale=1.0 / hidden_dim,
            bias=zero_bias,
        )

        # Pass 2: Normalize and apply weight
        for h in nl.affine_range(math.ceil(hidden_dim / HIDDEN_TILE)):
            h_start = h * HIDDEN_TILE
            h_end = min(hidden_dim, h_start + HIDDEN_TILE)

            x = nl.ndarray((BATCH_TILE, HIDDEN_TILE), dtype=a.dtype, buffer=nl.sbuf)
            nisa.dma_copy(
                dst=x, src=a[b_start:b_end, h_start:h_end]
            )

            g_tile = nl.ndarray((1, HIDDEN_TILE), dtype=nl.float32, buffer=nl.sbuf)
            nisa.dma_copy(dst=g_tile, src=g[0:1, h_start:h_end])

            # Broadcast g over the batch tile: ones_vec x g_tile
            g_bcast = nl.ndarray(
                (BATCH_TILE, HIDDEN_TILE), dtype=nl.float32, buffer=nl.psum
            )
            nisa.nc_matmul(dst=g_bcast, stationary=ones_vec, moving=g_tile)

            # x_out = (x * rms_inv) * g
            x_out = nl.ndarray((BATCH_TILE, HIDDEN_TILE), dtype=a.dtype, buffer=nl.sbuf)
            nisa.scalar_tensor_tensor(
                dst=x_out,
                data=x,
                op0=nl.multiply,
                operand0=rms_inv,
                op1=nl.multiply,
                operand1=g_bcast,
            )

            nisa.dma_copy(
                dst=out_tensor[b_start:b_end, h_start:h_end],
                src=x_out,
            )

    return out_tensor