24 changes: 24 additions & 0 deletions src/chop/nn/quantized/functional/__init__.py
@@ -42,6 +42,8 @@
bmm_minifloat_ieee,
bmm_binary,
bmm_ternary,
bmm_mxfp,
bmm_mxint,
matmul_block_fp,
matmul_block_log,
matmul_block_minifloat,
@@ -51,7 +53,14 @@
matmul_minifloat_ieee,
matmul_binary,
matmul_ternary,
matmul_mxfp,
matmul_mxint,
)
from .softmax import softmax_mxfp, softmax_mxint, softmax_minifloat
from .silu import silu_mxfp, silu_mxint, silu_minifloat
from .rope import rope_mxfp, rope_mxint, rope_minifloat
from .kvcache import kv_cache_mxfp, kv_cache_mxint

from .mult import (
mult_block_fp,
mult_block_log,
@@ -203,6 +212,10 @@
"matmul_block_log": matmul_block_log,
"matmul_binary": matmul_binary,
"matmul_ternary": matmul_ternary,
"matmul_mxfp": matmul_mxfp,
"matmul_mxint": matmul_mxint,
"bmm_mxfp": bmm_mxfp,
"bmm_mxint": bmm_mxint,
"relu_block_minifloat": relu_block_minifloat,
"relu_integer": relu_integer,
"relu_fixed": relu_integer,
@@ -277,4 +290,15 @@
"linear_ternary": linearTernary,
"linear_lutnet": linearLUT,
"linear_logicnets": linearLogicNets,
"softmax_mxfp": softmax_mxfp,
"softmax_mxint": softmax_mxint,
"softmax_minifloat": softmax_minifloat,
"silu_mxfp": silu_mxfp,
"silu_mxint": silu_mxint,
"silu_minifloat": silu_minifloat,
"rope_mxfp": rope_mxfp,
"rope_mxint": rope_mxint,
"rope_minifloat": rope_minifloat,
"kv_cache_mxfp": kv_cache_mxfp,
"kv_cache_mxint": kv_cache_mxint,
}
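
The new functional ops are added both as direct imports and as entries in the name-keyed mapping extended above, so they can be resolved either way. A quick import check (a minimal sketch; it only assumes a working install of chop):

from chop.nn.quantized.functional import (
    matmul_mxfp,
    bmm_mxint,
    softmax_mxfp,
    silu_mxint,
    rope_minifloat,
    kv_cache_mxfp,
)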
43 changes: 43 additions & 0 deletions src/chop/nn/quantized/functional/kvcache.py
@@ -0,0 +1,43 @@
from functools import partial

from torch import Tensor

from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer


def kv_cache_mxfp(
key_states: Tensor,
value_states: Tensor,
config: dict = None,
) -> tuple[Tensor, Tensor]:
x_block_size = config["data_in_block_size"]
x_exp_bits = config["data_in_exponent_width"]
x_frac_bits = config["data_in_frac_width"]

x_quantizer = partial(
mxfp_quantizer,
block_size=x_block_size,
element_exp_bits=x_exp_bits,
element_frac_bits=x_frac_bits,
block_dim=-1,
)

return x_quantizer(key_states), x_quantizer(value_states)


def kv_cache_mxint(
key_states: Tensor,
value_states: Tensor,
config: dict = None,
) -> tuple[Tensor, Tensor]:
x_block_size = config["data_in_block_size"]
x_element_bits = config["data_in_width"]

x_quantizer = partial(
mxint_quantizer,
block_size=x_block_size,
element_bits=x_element_bits,
block_dim=-1,
)

return x_quantizer(key_states), x_quantizer(value_states)
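
Both kv-cache helpers fake-quantize the key and value tensors with the same input config, blocking along the last (head) dimension, and return them as a tuple. A hedged usage sketch; tensor shapes and bit-widths below are illustrative, not taken from the PR:

import torch
from chop.nn.quantized.functional import kv_cache_mxfp

# (batch, heads, seq_len, head_dim) -- illustrative shapes
keys = torch.randn(1, 8, 128, 64)
values = torch.randn(1, 8, 128, 64)
cfg = {
    "data_in_block_size": 32,     # elements sharing one block scale
    "data_in_exponent_width": 4,  # per-element exponent bits
    "data_in_frac_width": 3,      # per-element mantissa bits
}
# Both tensors are quantized along the last dimension (block_dim=-1).
q_keys, q_values = kv_cache_mxfp(keys, values, cfg)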
3 changes: 3 additions & 0 deletions src/chop/nn/quantized/functional/linear.py
@@ -16,6 +16,8 @@
binary_quantizer,
ternary_quantizer,
mxint_hardware,
mxfp_quantizer,
mxint_quantizer,
)


@@ -581,3 +583,4 @@ def linearMXIntHardware(
if out_config is not None:
out = out_quantizer(out)
return out

80 changes: 80 additions & 0 deletions src/chop/nn/quantized/functional/matmul.py
@@ -15,6 +15,8 @@
minifloat_ieee_quantizer,
binary_quantizer,
ternary_quantizer,
mxfp_quantizer,
mxint_quantizer,
)

# PyTorch has torch.matmul and torch.bmm for matrix multiplication
@@ -430,3 +432,81 @@ def bmm_block_minifloat(x, y, config):

def bmm_block_log(x, y, config):
return generic_matmul_block_log(x, y, config, style="bmm")


def generic_matmul_mxfp(x, y, config, style="matmul"):
bypass = config.get("bypass", False)
matmul = matmul_mapping[style]
if bypass:
return matmul(x, y)

x_block_size = config["data_in_block_size"]
x_exp_bits = config["data_in_exponent_width"]
x_frac_bits = config["data_in_frac_width"]
y_block_size = config["weight_block_size"]
y_exp_bits = config["weight_exponent_width"]
y_frac_bits = config["weight_frac_width"]

x_quantizer = partial(
mxfp_quantizer,
block_size=x_block_size,
element_exp_bits=x_exp_bits,
element_frac_bits=x_frac_bits,
block_dim=-1,
)
y_quantizer = partial(
mxfp_quantizer,
block_size=y_block_size,
element_exp_bits=y_exp_bits,
element_frac_bits=y_frac_bits,
block_dim=-1,
)

x = x_quantizer(x)
y = y_quantizer(y)
return matmul(x, y)


def generic_matmul_mxint(x, y, config, style="matmul"):
bypass = config.get("bypass", False)
matmul = matmul_mapping[style]
if bypass:
return matmul(x, y)

x_block_size = config["data_in_block_size"]
x_element_bits = config["data_in_width"]
y_block_size = config["weight_block_size"]
y_element_bits = config["weight_width"]

x_quantizer = partial(
mxint_quantizer,
block_size=x_block_size,
element_bits=x_element_bits,
block_dim=-1,
)
y_quantizer = partial(
mxint_quantizer,
block_size=y_block_size,
element_bits=y_element_bits,
block_dim=-1,
)

x = x_quantizer(x)
y = y_quantizer(y)
return matmul(x, y)


def matmul_mxfp(x, y, config):
return generic_matmul_mxfp(x, y, config, "matmul")


def matmul_mxint(x, y, config):
return generic_matmul_mxint(x, y, config, "matmul")


def bmm_mxfp(x, y, config):
return generic_matmul_mxfp(x, y, config, "bmm")


def bmm_mxint(x, y, config):
return generic_matmul_mxint(x, y, config, "bmm")
101 changes: 101 additions & 0 deletions src/chop/nn/quantized/functional/rope.py
@@ -0,0 +1,101 @@
from functools import partial

import torch
from torch import Tensor

from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer
from chop.nn.quantizers._minifloat_mx import MinifloatMeta, minifloat_quantizer_sim


def rotate_half(x: Tensor) -> Tensor:
"""Rotate half the last dimension (for RoPE)."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def _apply_rope(q, k, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)

seq_len = q.size(-2)
cos = cos[..., :seq_len, :]
sin = sin[..., :seq_len, :]

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def rope_mxfp(
q: Tensor,
k: Tensor,
cos: Tensor,
sin: Tensor,
config: dict = None,
unsqueeze_dim: int = 1,
) -> tuple[Tensor, Tensor]:
x_block_size = config["data_in_block_size"]
x_exp_bits = config["data_in_exponent_width"]
x_frac_bits = config["data_in_frac_width"]

x_quantizer = partial(
mxfp_quantizer,
block_size=x_block_size,
element_exp_bits=x_exp_bits,
element_frac_bits=x_frac_bits,
block_dim=-1,
)

cos = x_quantizer(cos)
sin = x_quantizer(sin)
return _apply_rope(q, k, cos, sin, unsqueeze_dim)


def rope_mxint(
q: Tensor,
k: Tensor,
cos: Tensor,
sin: Tensor,
config: dict = None,
unsqueeze_dim: int = 1,
) -> tuple[Tensor, Tensor]:
x_block_size = config["data_in_block_size"]
x_element_bits = config["data_in_width"]

x_quantizer = partial(
mxint_quantizer,
block_size=x_block_size,
element_bits=x_element_bits,
block_dim=-1,
)

cos = x_quantizer(cos)
sin = x_quantizer(sin)
return _apply_rope(q, k, cos, sin, unsqueeze_dim)


def rope_minifloat(
q: Tensor,
k: Tensor,
cos: Tensor,
sin: Tensor,
config: dict = None,
unsqueeze_dim: int = 1,
) -> tuple[Tensor, Tensor]:
x_exp_bits = config["data_in_exponent_width"]
x_frac_bits = config["data_in_frac_width"]

x_quantizer = partial(
minifloat_quantizer_sim,
minifloat_meta=MinifloatMeta(
exp_bits=x_exp_bits,
frac_bits=x_frac_bits,
is_finite=config.get("data_in_is_finite", True),
round_mode=config.get("data_in_round_mode", "rn"),
),
)

cos = x_quantizer(cos)
sin = x_quantizer(sin)
return _apply_rope(q, k, cos, sin, unsqueeze_dim)
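
All three RoPE variants quantize only the rotary tables (cos and sin); q and k enter the rotation at full precision, and the output is not re-quantized. A usage sketch with the MXINT variant (shapes and widths are illustrative, not taken from the PR):

import torch
from chop.nn.quantized.functional import rope_mxint

# (batch, heads, seq_len, head_dim) -- illustrative shapes
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
# Rotary tables, broadcast over the head dimension via unsqueeze_dim=1
cos = torch.randn(1, 128, 64)
sin = torch.randn(1, 128, 64)
cfg = {"data_in_block_size": 32, "data_in_width": 8}
q_rot, k_rot = rope_mxint(q, k, cos, sin, cfg)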
57 changes: 57 additions & 0 deletions src/chop/nn/quantized/functional/silu.py
@@ -0,0 +1,57 @@
from functools import partial

import torch
from torch import Tensor

from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer
from chop.nn.quantizers._minifloat_mx import MinifloatMeta, minifloat_quantizer_sim


def silu_mxfp(x: Tensor, config: dict = None) -> Tensor:
x_block_size = config["data_in_block_size"]
x_exp_bits = config["data_in_exponent_width"]
x_frac_bits = config["data_in_frac_width"]

x_quantizer = partial(
mxfp_quantizer,
block_size=x_block_size,
element_exp_bits=x_exp_bits,
element_frac_bits=x_frac_bits,
block_dim=-1,
)

x = x_quantizer(x)
return torch.nn.functional.silu(x)


def silu_mxint(x: Tensor, config: dict = None) -> Tensor:
x_block_size = config["data_in_block_size"]
x_element_bits = config["data_in_width"]

x_quantizer = partial(
mxint_quantizer,
block_size=x_block_size,
element_bits=x_element_bits,
block_dim=-1,
)

x = x_quantizer(x)
return torch.nn.functional.silu(x)


def silu_minifloat(x: Tensor, config: dict = None) -> Tensor:
x_exp_bits = config["data_in_exponent_width"]
x_frac_bits = config["data_in_frac_width"]

x_quantizer = partial(
minifloat_quantizer_sim,
minifloat_meta=MinifloatMeta(
exp_bits=x_exp_bits,
frac_bits=x_frac_bits,
is_finite=config.get("data_in_is_finite", True),
round_mode=config.get("data_in_round_mode", "rn"),
),
)

x = x_quantizer(x)
return torch.nn.functional.silu(x)
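
Each SiLU variant fake-quantizes the activation input and then applies the standard torch.nn.functional.silu, leaving the output in full precision. A minimal sketch for the minifloat variant (bit-widths are illustrative):

import torch
from chop.nn.quantized.functional import silu_minifloat

x = torch.randn(4, 256)
cfg = {
    "data_in_exponent_width": 5,
    "data_in_frac_width": 2,
    # Optional keys, defaulting as in the PR:
    # "data_in_is_finite": True,
    # "data_in_round_mode": "rn",
}
y = silu_minifloat(x, cfg)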