From 56a5aa23fec3e640bc3e3789e5d622fc0b1aa7b0 Mon Sep 17 00:00:00 2001 From: jiaeenie Date: Tue, 10 Feb 2026 07:26:24 +0000 Subject: [PATCH 1/5] Port mx quantizers --- src/chop/nn/quantizers/__init__.py | 5 + .../nn/quantizers/_minifloat_mx/__init__.py | 15 ++ src/chop/nn/quantizers/_minifloat_mx/fake.py | 149 +++++++++++ src/chop/nn/quantizers/_minifloat_mx/meta.py | 42 ++++ src/chop/nn/quantizers/mxfp/__init__.py | 14 ++ src/chop/nn/quantizers/mxfp/fake.py | 104 ++++++++ src/chop/nn/quantizers/mxfp/helpers.py | 63 +++++ src/chop/nn/quantizers/mxfp/meta.py | 67 +++++ src/chop/nn/quantizers/mxfp/mxfp.py | 212 ++++++++++++++++ src/chop/nn/quantizers/mxint/__init__.py | 14 ++ src/chop/nn/quantizers/mxint/fake.py | 69 ++++++ src/chop/nn/quantizers/mxint/meta.py | 34 +++ src/chop/nn/quantizers/mxint/mxint.py | 232 ++++++++++++++++++ 13 files changed, 1020 insertions(+) create mode 100644 src/chop/nn/quantizers/_minifloat_mx/__init__.py create mode 100644 src/chop/nn/quantizers/_minifloat_mx/fake.py create mode 100644 src/chop/nn/quantizers/_minifloat_mx/meta.py create mode 100644 src/chop/nn/quantizers/mxfp/__init__.py create mode 100644 src/chop/nn/quantizers/mxfp/fake.py create mode 100644 src/chop/nn/quantizers/mxfp/helpers.py create mode 100644 src/chop/nn/quantizers/mxfp/meta.py create mode 100644 src/chop/nn/quantizers/mxfp/mxfp.py create mode 100644 src/chop/nn/quantizers/mxint/__init__.py create mode 100644 src/chop/nn/quantizers/mxint/fake.py create mode 100644 src/chop/nn/quantizers/mxint/meta.py create mode 100644 src/chop/nn/quantizers/mxint/mxint.py diff --git a/src/chop/nn/quantizers/__init__.py b/src/chop/nn/quantizers/__init__.py index c5b1f4c8d..42ad5e8f0 100644 --- a/src/chop/nn/quantizers/__init__.py +++ b/src/chop/nn/quantizers/__init__.py @@ -8,6 +8,9 @@ from .minifloat import minifloat_denorm_quantizer, minifloat_ieee_quantizer from .quantizers_for_hw import integer_quantizer_for_hw, integer_floor_quantizer_for_hw from .mxint_hardware import 
mxint_hardware +from .mxfp import mxfp_quantizer +from .mxint import mxint_quantizer + quantizer_map = { "log": log_quantizer, @@ -20,4 +23,6 @@ "binary": binary_quantizer, "ternary": ternary_quantizer, "mxint_hardware": mxint_hardware, + "mxfp": mxfp_quantizer, + "mxint": mxint_quantizer, } diff --git a/src/chop/nn/quantizers/_minifloat_mx/__init__.py b/src/chop/nn/quantizers/_minifloat_mx/__init__.py new file mode 100644 index 000000000..053f7f8f0 --- /dev/null +++ b/src/chop/nn/quantizers/_minifloat_mx/__init__.py @@ -0,0 +1,15 @@ +""" +Internal minifloat module for MX-format quantizers. + +This is used internally by MXFP. +""" + +from .meta import MinifloatMeta, MinifloatTensorMeta +from .fake import extract_minifloat_component, compose_minifloat_component + +__all__ = [ + "MinifloatMeta", + "MinifloatTensorMeta", + "extract_minifloat_component", + "compose_minifloat_component", +] diff --git a/src/chop/nn/quantizers/_minifloat_mx/fake.py b/src/chop/nn/quantizers/_minifloat_mx/fake.py new file mode 100644 index 000000000..0b2c85ae9 --- /dev/null +++ b/src/chop/nn/quantizers/_minifloat_mx/fake.py @@ -0,0 +1,149 @@ +""" +Fake minifloat quantization operations. +""" + +import torch +from torch import Tensor + +from .meta import MinifloatMeta + + +def extract_minifloat_component(x: Tensor, minifloat_meta: MinifloatMeta) -> Tensor: + """ + Extract minifloat representation from float tensor. 
+ + Args: + x: Input float tensor + minifloat_meta: Minifloat format specification + + Returns: + Tensor of uint16 containing minifloat representation + """ + y_exp_bits = minifloat_meta.exp_bits + y_frac_bits = minifloat_meta.frac_bits + always_finite = minifloat_meta.is_finite + round_mode = minifloat_meta.round_mode + + y_exp_bias = (1 << (y_exp_bits - 1)) - 1 + y_exp_max = (1 << y_exp_bits) - 1 if always_finite else (1 << y_exp_bits) - 2 + y_exp_max_biased = y_exp_max - y_exp_bias + y_exp_min = 0 + y_exp_min_biased = y_exp_min - y_exp_bias + y_frac_max = (1 << y_frac_bits) - 1 + + x = x.to(torch.float32) + y_sign = x < 0 + x_int32 = x.abs().view(torch.int32) + flush_to_zero = (x_int32 & 0x7F800000) == 0 + x_normal = torch.where(flush_to_zero, 0.0, x) + + x_frac, x_exp = x_normal.abs().frexp() + x_frac = x_frac * 2 + x_exp = x_exp - 1 + + if not always_finite: + x_is_inf = x.isinf() + x_is_nan = x.isnan() + + y_exp = x_exp + underflow = y_exp < y_exp_min_biased + overflow = y_exp > y_exp_max_biased + y_exp = y_exp + y_exp_bias + + y_frac = x_frac.view(torch.int32) & 0x7FFFFF + + if round_mode == "rz": + y_frac = y_frac >> (23 - y_frac_bits) + else: + y_frac = (y_frac >> 8).float() + div = 1 << (15 - y_frac_bits) + y_frac = y_frac / div + if round_mode == "ru": + y_frac = y_frac.ceil() + elif round_mode == "rd": + y_frac = y_frac.floor() + elif round_mode == "rn": + y_frac = y_frac.round() + else: + raise ValueError(f"Unknown rounding mode: {round_mode}") + y_frac = y_frac.to(torch.int32) + + y_is_subnormal = (y_exp == y_exp_min) & (y_frac != 0) + y_frac = torch.where(y_is_subnormal, (y_frac | (1 << y_frac_bits)) >> 1, y_frac) + + # underflow -> 0 + y_frac = torch.where(underflow, 0, y_frac) + y_exp = torch.where(underflow, 0, y_exp) + # overflow -> max + y_frac = torch.where(overflow, y_frac_max, y_frac) + y_exp = torch.where(overflow, y_exp_max, y_exp) + # flush to zero + y_frac = torch.where(flush_to_zero, 0, y_frac) + y_exp = torch.where(flush_to_zero, 0, 
y_exp) + + if not always_finite: + y_frac = torch.where(x_is_inf, 0, y_frac) + y_frac = torch.where(x_is_nan, (1 << y_frac_bits) - 1, y_frac) + y_exp = torch.where(x_is_inf, y_exp_max, y_exp) + y_exp = torch.where(x_is_nan, y_exp_max, y_exp) + + y = (y_exp << y_frac_bits) | y_frac + y = torch.where(y_sign, y + (1 << (y_exp_bits + y_frac_bits)), y) + y = y.to(torch.uint16) + return y + + +def compose_minifloat_component( + elements: Tensor, + minifloat_meta: MinifloatMeta, + output_dtype: torch.dtype, +) -> Tensor: + """ + Compose float tensor from minifloat representation. + + Args: + elements: Tensor of uint16 containing minifloat representation + minifloat_meta: Minifloat format specification + output_dtype: Desired output dtype + + Returns: + Dequantized float tensor + """ + exp_bits = minifloat_meta.exp_bits + frac_bits = minifloat_meta.frac_bits + always_finite = minifloat_meta.is_finite + + x_sign_mask = 1 << (exp_bits + frac_bits) + x_frac_mask = (1 << frac_bits) - 1 + x_exp_bias = (1 << (exp_bits - 1)) - 1 + + assert elements.dtype == torch.uint16 + elements = elements.to(torch.int32) + y_sign = (elements & x_sign_mask) << (31 - (exp_bits + frac_bits)) + + elements = elements & 0x7FFF + x_exp = (elements >> frac_bits) & ((1 << exp_bits) - 1) + x_frac = elements & x_frac_mask + is_subnormal = (x_exp == 0) & (x_frac != 0) + is_zero = (x_exp == 0) & (x_frac == 0) + + if not always_finite: + y_is_not_finite = x_exp == ((1 << exp_bits) - 1) + y_is_inf = y_is_not_finite & (x_frac == 0) + y_is_nan = y_is_not_finite & (x_frac != 0) + + y_exp = x_exp - x_exp_bias + y_exp = torch.where(is_subnormal, y_exp + 1, y_exp) + y_exp = torch.exp2(y_exp) + y_frac = x_frac.to(torch.float32) + y_frac = y_frac / (1 << frac_bits) + y_frac = torch.where(is_subnormal, y_frac, y_frac + 1.0) + y = y_exp * y_frac + + if not always_finite: + y = torch.where(y_is_inf, float("inf"), y) + y = torch.where(y_is_nan, float("nan"), y) + y = torch.where(is_zero, 0.0, y) + y = torch.where(y_sign 
!= 0, -y, y) + y = y.to(output_dtype) + return y diff --git a/src/chop/nn/quantizers/_minifloat_mx/meta.py b/src/chop/nn/quantizers/_minifloat_mx/meta.py new file mode 100644 index 000000000..94e48f9b8 --- /dev/null +++ b/src/chop/nn/quantizers/_minifloat_mx/meta.py @@ -0,0 +1,42 @@ +""" +Minifloat metadata for MX-format quantizers. +""" + +import functools +from dataclasses import dataclass +from typing import Literal + + +@dataclass(frozen=True) +class MinifloatMeta: + """ + Metadata for minifloat types. + + Args: + exp_bits: Number of exponent bits + frac_bits: Number of fraction bits + is_finite: Whether the minifloat type supports inf/nan + round_mode: Rounding mode - "rn" (nearest), "rd" (down), "ru" (up), "rz" (truncate) + """ + + exp_bits: int + frac_bits: int + is_finite: bool + round_mode: Literal["rn", "rd", "ru", "rz"] + + def __post_init__(self): + assert self.exp_bits > 0 + assert self.frac_bits > 0 + assert self.exp_bits + self.frac_bits < 16 + + @functools.cached_property + def n_bits(self) -> int: + return self.exp_bits + self.frac_bits + 1 + + +@dataclass +class MinifloatTensorMeta: + device: str + dtype: str + shape: tuple[int, ...] + meta: MinifloatMeta diff --git a/src/chop/nn/quantizers/mxfp/__init__.py b/src/chop/nn/quantizers/mxfp/__init__.py new file mode 100644 index 000000000..9005c584c --- /dev/null +++ b/src/chop/nn/quantizers/mxfp/__init__.py @@ -0,0 +1,14 @@ +""" +MXFP quantizer module. +""" + +from .meta import MXFPMeta, MXFPTensorMeta +from .mxfp import mxfp_quantizer, mxfp_quantizer_sim, MXFPQuantize + +__all__ = [ + "MXFPMeta", + "MXFPTensorMeta", + "mxfp_quantizer", + "mxfp_quantizer_sim", + "MXFPQuantize", +] diff --git a/src/chop/nn/quantizers/mxfp/fake.py b/src/chop/nn/quantizers/mxfp/fake.py new file mode 100644 index 000000000..5a1930bc2 --- /dev/null +++ b/src/chop/nn/quantizers/mxfp/fake.py @@ -0,0 +1,104 @@ +""" +Fake MXFP quantization operations. 
+""" + +import torch +from torch import Tensor + +from .._minifloat_mx import extract_minifloat_component, compose_minifloat_component +from .meta import MXFPMeta + + +def extract_mxfp_components( + tensor: Tensor, mxfp_meta: MXFPMeta, percentile: float = 1.0 +) -> tuple[Tensor, Tensor]: + """ + Extract MXFP components (scales and elements) from a tensor. + + Args: + tensor: Input tensor (already flattened to [n_blocks, block_size]) + mxfp_meta: MXFP format specification + percentile: Percentile for scale calculation (1.0 = max) + + Returns: + Tuple of (scales_uint8, elements_uint8) + """ + tensor = tensor.float() + B = mxfp_meta.block_size + assert tensor.numel() % B == 0 + + n_blocks = tensor.numel() // B + + fp32_exp_mask = 0x7F800000 + + sc_exp_max = (1 << mxfp_meta.scale_exp_bits) - 1 + sc_exp_min = 0 + sc_exp_bias = (1 << (mxfp_meta.scale_exp_bits - 1)) - 1 + sc_exp_max_biased = sc_exp_max - sc_exp_bias + sc_exp_min_biased = sc_exp_min - sc_exp_bias + + el_exp_max = ( + (1 << mxfp_meta.element_exp_bits) - 1 + if mxfp_meta.element_is_finite + else (1 << mxfp_meta.element_exp_bits) - 2 + ) + el_exp_bias = (1 << (mxfp_meta.element_exp_bits - 1)) - 1 + el_exp_max_biased = el_exp_max - el_exp_bias + + tensor = tensor.flatten() + tensor = tensor.reshape(n_blocks, B) + + x_int32 = tensor.view(torch.int32) + # flush subnormal to zero + flush_to_zero = (x_int32 & fp32_exp_mask) == 0 + tensor = torch.where(flush_to_zero, 0.0, tensor) + + shared_exp = tensor.abs().quantile(percentile, dim=1, keepdim=True) + + shared_exp = shared_exp.log2().floor().to(torch.int32) + shared_exp -= el_exp_max_biased + shared_exp = shared_exp.clamp(sc_exp_min_biased, sc_exp_max_biased) + scales_uint = shared_exp + sc_exp_bias + scales_uint = torch.where(flush_to_zero.all(dim=1, keepdim=True), 0, scales_uint) + scales_uint = scales_uint.to(torch.uint8) + + scales_fp = torch.exp2(shared_exp) + + minifloats = torch.where(flush_to_zero, 0.0, tensor / scales_fp) + elements = 
extract_minifloat_component(minifloats, mxfp_meta.element_meta) + elements = elements.view(torch.uint16) + elements = elements.to(torch.uint8) + + return scales_uint, elements + + +def compose_mxfp_tensor( + scales: Tensor, + elements: Tensor, + mxfp_meta: MXFPMeta, + output_dtype: torch.dtype, +) -> Tensor: + """ + Compose tensor from MXFP components. + + Args: + scales: Shared scales (uint8) + elements: Quantized elements (uint8) + mxfp_meta: MXFP format specification + output_dtype: Desired output dtype + + Returns: + Dequantized tensor + """ + assert scales.dtype == torch.uint8 + assert elements.dtype == torch.uint8 + + sc_exp_bias = (1 << (mxfp_meta.scale_exp_bits - 1)) - 1 + scales_fp = torch.exp2(scales.to(torch.int32) - sc_exp_bias) + minifloats = compose_minifloat_component( + elements.to(torch.uint16), mxfp_meta.element_meta, output_dtype=torch.float32 + ) + + dequantized = minifloats * scales_fp + dequantized = dequantized.flatten().to(output_dtype) + return dequantized diff --git a/src/chop/nn/quantizers/mxfp/helpers.py b/src/chop/nn/quantizers/mxfp/helpers.py new file mode 100644 index 000000000..1c6deb197 --- /dev/null +++ b/src/chop/nn/quantizers/mxfp/helpers.py @@ -0,0 +1,63 @@ +""" +Helper functions for MX-format quantizers. +""" + +from torch import Tensor + + +def flatten_for_quantize(tensor: Tensor, block_dim: int) -> Tensor: + """ + Permute tensor to move block dimension to last position and flatten. 
+ + Args: + tensor: Input tensor + block_dim: Dimension to use for blocking + + Returns: + Flattened tensor with block_dim moved to last position + """ + ori_shape = tuple(tensor.shape) + ndim = len(ori_shape) + block_dim = block_dim % ndim + + # Create permutation to move block_dim to last position + permute = list(range(ndim)) + permute.append(permute.pop(block_dim)) + + tensor = tensor.permute(permute) + tensor = tensor.flatten() + return tensor + + +def permute_for_dequantize( + flatten_tensor: Tensor, + ori_shape: tuple[int, ...], + block_dim: int, +) -> Tensor: + """ + Reshape flattened tensor back to original shape after dequantization. + + Args: + flatten_tensor: Flattened tensor from quantization + ori_shape: Original tensor shape before flattening + block_dim: Original block dimension + + Returns: + Tensor restored to original shape + """ + ndim = len(ori_shape) + block_dim = block_dim % ndim + + # Create the shape after moving block_dim to last position + permuted_shape = list(ori_shape) + permuted_shape.append(permuted_shape.pop(block_dim)) + + # Reshape from flattened form to intermediate permuted form + tensor = flatten_tensor.reshape(permuted_shape) + + # Create inverse permutation to restore original dimension order + inverse_permute = list(range(ndim)) + inverse_permute.insert(block_dim, inverse_permute.pop(-1)) + + tensor = tensor.permute(inverse_permute) + return tensor diff --git a/src/chop/nn/quantizers/mxfp/meta.py b/src/chop/nn/quantizers/mxfp/meta.py new file mode 100644 index 000000000..fcd66bf26 --- /dev/null +++ b/src/chop/nn/quantizers/mxfp/meta.py @@ -0,0 +1,67 @@ +""" +MXFP metadata classes. +""" + +import functools +from dataclasses import dataclass +from typing import Literal + +from .._minifloat_mx import MinifloatMeta + + +@dataclass(frozen=True) +class MXFPMeta: + """ + Metadata for MXFP (Mixed-exponent Floating Point) format. + + MXFP uses block-wise shared exponents with minifloat elements. 
+ + Args: + block_size: Number of elements per block for shared exponent + scale_exp_bits: Bits for shared scale exponent (typically 8) + element_exp_bits: Exponent bits per element (e.g., 4 for E4M3) + element_frac_bits: Fraction bits per element (e.g., 3 for E4M3) + element_is_finite: Whether elements support inf/nan + round_mode: Rounding mode + """ + + block_size: int + scale_exp_bits: int + element_exp_bits: int + element_frac_bits: int + element_is_finite: bool + round_mode: Literal["rn", "ru", "rd", "rz"] + + def __post_init__(self): + legal_scale_exp_bits = (8,) + assert self.scale_exp_bits in legal_scale_exp_bits, ( + f"Invalid scale exponent bits: {self.scale_exp_bits}. " + f"Legal values are: {legal_scale_exp_bits}." + ) + legal_element_exp_frac_bits = ((4, 3), (5, 2), (2, 3), (3, 2), (2, 1), (1, 2)) + el_exp_frac = (self.element_exp_bits, self.element_frac_bits) + assert el_exp_frac in legal_element_exp_frac_bits, ( + f"Invalid element exp/frac bits: {el_exp_frac}. " + f"Legal values are: {legal_element_exp_frac_bits}." + ) + + @functools.cached_property + def element_meta(self) -> MinifloatMeta: + """Returns MinifloatMeta for the element part of MXFP format.""" + return MinifloatMeta( + exp_bits=self.element_exp_bits, + frac_bits=self.element_frac_bits, + is_finite=self.element_is_finite, + round_mode=self.round_mode, + ) + + +@dataclass(frozen=True) +class MXFPTensorMeta: + """Runtime metadata for an MXFP tensor.""" + + device: str + dtype: str + shape: tuple[int, ...] + block_dim: int + meta: MXFPMeta diff --git a/src/chop/nn/quantizers/mxfp/mxfp.py b/src/chop/nn/quantizers/mxfp/mxfp.py new file mode 100644 index 000000000..715250ee9 --- /dev/null +++ b/src/chop/nn/quantizers/mxfp/mxfp.py @@ -0,0 +1,212 @@ +""" +MXFP quantizer. 
+""" + +import torch +from torch import Tensor +from tqdm import tqdm + +from .meta import MXFPMeta, MXFPTensorMeta +from .helpers import flatten_for_quantize, permute_for_dequantize +from .fake import extract_mxfp_components, compose_mxfp_tensor + + +def mxfp_quantizer_sim( + tensor: Tensor, + block_dim: int, + mxfp_meta: MXFPMeta, + act_tensor: Tensor | None = None, + dtype: torch.dtype | None = None, + quantile_search: bool = False, + cali_batch_size: int = 32, +) -> Tensor: + """ + Quantize and dequantize a tensor using MXFP format. + + Args: + tensor: Input tensor to quantize + block_dim: Dimension to apply block quantization + mxfp_meta: MXFP format specification + act_tensor: Optional activation tensor for GPTQ-style calibration + dtype: Output dtype (default: same as input) + quantile_search: Enable quantile-based clipping search + cali_batch_size: Batch size for calibration + + Returns: + Dequantized tensor + """ + out_dq = torch.zeros_like(tensor) + + if quantile_search: + qtensor = tensor.flatten() + B = mxfp_meta.block_size + + qtensor = qtensor.reshape(-1, B) + best = torch.full([qtensor.shape[0]], float('inf'), device=tensor.device, dtype=tensor.dtype) + best_scales, best_elements, tensor_meta = _extract_with_meta( + tensor, block_dim, mxfp_meta, percentile=1.0 + ) + + percentiles = [1.0, 0.995, 0.99, 0.97, 0.95, 0.93, 0.90, 0.80, 0.70, 0.60, 0.50] + for percentile in percentiles: + scales, elements, tensor_meta = _extract_with_meta( + tensor, block_dim, mxfp_meta, percentile=percentile + ) + scale_bias = 2 ** (mxfp_meta.scale_exp_bits - 1) - 1 + q = elements / 2 ** (mxfp_meta.element_frac_bits - 1) * 2 ** (scales - scale_bias) + q = q.to(dtype=qtensor.dtype) + + if act_tensor is not None: + BATCH_SIZE = cali_batch_size + last_dim = act_tensor.shape[-1] + if last_dim != B: + assert last_dim % B == 0 + act_tensor = act_tensor.view(*act_tensor.shape[:-1], last_dim // B, B) + + total_batches = act_tensor.shape[0] + err = torch.zeros(qtensor.shape[0], 
device=tensor.device, dtype=tensor.dtype) + + with torch.no_grad(): + for b in tqdm(range(0, total_batches, BATCH_SIZE), desc="Batching quant output", disable=True): + act_b = act_tensor[b:b + BATCH_SIZE] + out_q = torch.matmul(act_b, q.T) + out_orig = torch.matmul(act_b, qtensor.T) + err += torch.norm(out_q - out_orig, p=2, dim=(0, 1)) + + del act_b, out_q, out_orig + torch.cuda.empty_cache() + + torch.cuda.empty_cache() + else: + q -= qtensor + q.abs_() + q.pow_(2) + err = torch.sum(q, 1) + + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + best_scales[tmp] = scales[tmp] + best_elements[tmp] = elements[tmp] + + else: + best_scales, best_elements, tensor_meta = _extract_with_meta( + tensor, block_dim, mxfp_meta, percentile=1.0 + ) + + out_dq = compose_mxfp_tensor(best_scales, best_elements, tensor_meta.meta, output_dtype=dtype or tensor.dtype) + out_dq = permute_for_dequantize(out_dq, tensor_meta.shape, tensor_meta.block_dim) + return out_dq + + +def _extract_with_meta( + tensor: Tensor, + block_dim: int, + mxfp_meta: MXFPMeta, + percentile: float = 1.0, +) -> tuple[Tensor, Tensor, MXFPTensorMeta]: + """Extract MXFP components with tensor metadata.""" + device = str(tensor.device) + ori_shape = tuple(tensor.shape) + ori_dtype = str(tensor.dtype).removeprefix("torch.") + ndim = len(ori_shape) + assert block_dim < ndim and block_dim >= -ndim + + tensor_flat = flatten_for_quantize(tensor, block_dim) + scales, elements = extract_mxfp_components(tensor_flat, mxfp_meta, percentile=percentile) + + tensor_meta = MXFPTensorMeta( + device=device, + dtype=ori_dtype, + shape=ori_shape, + block_dim=block_dim, + meta=mxfp_meta, + ) + return scales, elements, tensor_meta + + +# ============================================================================= +# Mase-style quantizer interface with STE +# ============================================================================= + + +class MXFPQuantize(torch.autograd.Function): + """Autograd function for MXFP 
quantization with STE gradient.""" + + @staticmethod + def forward( + ctx, + x: Tensor, + block_size: int, + element_exp_bits: int, + element_frac_bits: int, + block_dim: int, + scale_exp_bits: int, + quantile_search: bool, + ) -> Tensor: + meta = MXFPMeta( + block_size=block_size, + scale_exp_bits=scale_exp_bits, + element_exp_bits=element_exp_bits, + element_frac_bits=element_frac_bits, + element_is_finite=True, + round_mode="rn", + ) + return mxfp_quantizer_sim( + tensor=x, + block_dim=block_dim, + mxfp_meta=meta, + quantile_search=quantile_search, + ) + + @staticmethod + def backward(ctx, grad_output): + # STE: pass gradient through unchanged + grad_input = grad_output.clone() + return grad_input, None, None, None, None, None, None + + +def mxfp_quantizer( + x: Tensor, + block_size: int, + element_exp_bits: int, + element_frac_bits: int, + block_dim: int = -1, + scale_exp_bits: int = 8, + quantile_search: bool = False, +) -> Tensor: + """ + MXFP quantizer with mase-style interface. + + Converts tensor to MXFP format with block-wise shared exponent + and minifloat elements, then dequantizes back. 
+ + Args: + x: Input tensor to quantize + block_size: Number of elements per block for shared exponent (e.g., 32) + element_exp_bits: Exponent bits for each element (e.g., 4 for E4M3) + element_frac_bits: Fraction bits for each element (e.g., 3 for E4M3) + block_dim: Dimension to apply block quantization (-1 for last dim) + scale_exp_bits: Bits for shared scale exponent (default 8) + quantile_search: Enable quantile-based clipping search + + Returns: + Quantized tensor in dequantized form + + Example: + >>> x = torch.randn(4, 32) + >>> q = mxfp_quantizer(x, block_size=32, element_exp_bits=4, element_frac_bits=3) + + Common formats: + - E4M3: element_exp_bits=4, element_frac_bits=3 (8-bit element) + - E5M2: element_exp_bits=5, element_frac_bits=2 (8-bit element) + """ + return MXFPQuantize.apply( + x, + block_size, + element_exp_bits, + element_frac_bits, + block_dim, + scale_exp_bits, + quantile_search, + ) diff --git a/src/chop/nn/quantizers/mxint/__init__.py b/src/chop/nn/quantizers/mxint/__init__.py new file mode 100644 index 000000000..c1a10d5ee --- /dev/null +++ b/src/chop/nn/quantizers/mxint/__init__.py @@ -0,0 +1,14 @@ +""" +MXINT (Mixed-exponent Integer) quantizer module. +""" + +from .meta import MXIntMeta, MXIntTensorMeta +from .mxint import mxint_quantizer, mxint_quantizer_sim, MXIntQuantize + +__all__ = [ + "MXIntMeta", + "MXIntTensorMeta", + "mxint_quantizer", + "mxint_quantizer_sim", + "MXIntQuantize", +] diff --git a/src/chop/nn/quantizers/mxint/fake.py b/src/chop/nn/quantizers/mxint/fake.py new file mode 100644 index 000000000..3f430d57c --- /dev/null +++ b/src/chop/nn/quantizers/mxint/fake.py @@ -0,0 +1,69 @@ +""" +Fake MXINT quantization operations. +""" + +import torch +from torch import Tensor + +from .meta import MXIntMeta + + +def extract_mxint_components( + x: Tensor, mxint_meta: MXIntMeta, percentile: float = 1.0 +) -> tuple[Tensor, Tensor]: + """ + Extract MXINT components (scale and elements) from a tensor. 
+ + Args: + x: Input tensor (already flattened) + mxint_meta: MXINT format specification + percentile: Percentile for scale calculation (1.0 = max) + + Returns: + Tuple of (scale, quantized_mantissa) + """ + B = mxint_meta.block_size + assert x.numel() % B == 0, ( + f"Input tensor size {x.numel()} is not divisible by block size {B}." + ) + n_blocks = x.numel() // B + + x = x.flatten() + x = x.reshape(n_blocks, B) + + ori_dtype = x.dtype + # quantile needs fp32 + x_max = x.abs().to(torch.float32).quantile(percentile, dim=1, keepdim=True).to(ori_dtype) + + scale = x_max.log2().ceil() + scale_bias = 2 ** (mxint_meta.scale_bits - 1) - 1 + x = x / 2 ** scale + x_mant = x * 2 ** (mxint_meta.element_bits - 1) + scale = scale + scale_bias + scale = scale.clamp(min=0, max=2 ** mxint_meta.scale_bits - 1) + x_mant = x_mant.round().clamp( + min=-2 ** (mxint_meta.element_bits - 1), + max=2 ** (mxint_meta.element_bits - 1) - 1 + ) + + return scale, x_mant + + +def compose_mxint_tensor( + shared_scales: Tensor, + elements: Tensor, + mxint_meta: MXIntMeta, +) -> Tensor: + """ + Compose tensor from MXINT components. + + Args: + shared_scales: Shared scales tensor + elements: Quantized elements tensor + mxint_meta: MXINT format specification + + Returns: + Dequantized tensor + """ + scale_bias = 2 ** (mxint_meta.scale_bits - 1) - 1 + return elements / 2 ** (mxint_meta.element_bits - 1) * 2 ** (shared_scales - scale_bias) diff --git a/src/chop/nn/quantizers/mxint/meta.py b/src/chop/nn/quantizers/mxint/meta.py new file mode 100644 index 000000000..1ed4d9a88 --- /dev/null +++ b/src/chop/nn/quantizers/mxint/meta.py @@ -0,0 +1,34 @@ +""" +MXINT metadata classes. +""" + +from dataclasses import dataclass + + +@dataclass +class MXIntMeta: + """ + Metadata for MXINT (Mixed-exponent Integer) format. + + MXINT uses block-wise shared scale with integer elements. 
+ + Args: + block_size: Number of elements per block + scale_bits: Bits for shared scale (typically 8) + element_bits: Bits per element (e.g., 4 or 8) + """ + + block_size: int + scale_bits: int + element_bits: int + + +@dataclass +class MXIntTensorMeta: + """Runtime metadata for an MXINT tensor.""" + + device: str + dtype: str + shape: tuple[int, ...] + block_dim: int + meta: MXIntMeta diff --git a/src/chop/nn/quantizers/mxint/mxint.py b/src/chop/nn/quantizers/mxint/mxint.py new file mode 100644 index 000000000..800a582ac --- /dev/null +++ b/src/chop/nn/quantizers/mxint/mxint.py @@ -0,0 +1,232 @@ +""" +MXINT quantizer. +""" + +import torch +from torch import Tensor +from tqdm import tqdm + +from .meta import MXIntMeta, MXIntTensorMeta +from .fake import extract_mxint_components, compose_mxint_tensor +from ..mxfp.helpers import flatten_for_quantize, permute_for_dequantize + + +def mxint_quantizer_sim( + tensor: Tensor, + block_dim: int, + mxint_meta: MXIntMeta, + act_tensor: Tensor | None = None, + dtype: torch.dtype | None = None, + quantile_search: bool = False, + cali_batch_size: int = 32, +) -> Tensor: + """ + Quantize and dequantize a tensor using MXINT format. 
+ + Args: + tensor: Input tensor to quantize + block_dim: Dimension to apply block quantization + mxint_meta: MXINT format specification + act_tensor: Optional activation tensor for GPTQ-style calibration + dtype: Output dtype (default: same as input) + quantile_search: Enable quantile-based clipping search + cali_batch_size: Batch size for calibration + + Returns: + Dequantized tensor + """ + tensor_dtype = tensor.dtype + + if quantile_search: + qtensor = tensor.flatten() + B = mxint_meta.block_size + + qtensor = qtensor.reshape(-1, B) + + percentiles = torch.tensor( + [1.0, 0.995, 0.99, 0.97, 0.95, 0.93, 0.90, 0.80, 0.70, 0.60, 0.50], + device=tensor.device, + dtype=torch.float32, + ) + + device = str(tensor.device) + ori_shape = tuple(tensor.shape) + ori_dtype = str(tensor.dtype).removeprefix("torch.") + ndim = len(ori_shape) + assert block_dim < ndim and block_dim >= -ndim + + tensor_flat = flatten_for_quantize(tensor, block_dim) + + x = tensor_flat + n_blocks = x.numel() // B + + x = x.flatten() + x = x.reshape(n_blocks, B) + + tem_dtype = x.dtype + x_max = x.abs().to(torch.float32).quantile(percentiles, dim=1, keepdim=True).to(tem_dtype) + + scale = x_max.log2().ceil() + scale_bias = 2 ** (mxint_meta.scale_bits - 1) - 1 + x = x / 2 ** scale + x_mant = x * 2 ** (mxint_meta.element_bits - 1) + scale = scale + scale_bias + scale = scale.clamp(min=0, max=2 ** mxint_meta.scale_bits - 1) + x_mant = x_mant.round().clamp( + min=-2 ** (mxint_meta.element_bits - 1), + max=2 ** (mxint_meta.element_bits - 1) - 1 + ) + + quant_tensor = x_mant / 2 ** (mxint_meta.element_bits - 1) * 2 ** (scale - scale_bias) + + del x, x_max, scale, x_mant + torch.cuda.empty_cache() + quant_tensor = quant_tensor.reshape(len(percentiles), n_blocks, B) + + if act_tensor is None: + err = torch.norm(quant_tensor - qtensor, p=2, dim=-1) + min_err_idx = torch.argmin(err, dim=0).reshape(-1) + else: + BATCH_SIZE = cali_batch_size + last_dim = act_tensor.shape[-1] + if last_dim != B: + assert 
last_dim % B == 0 + act_tensor = act_tensor.view(*act_tensor.shape[:-1], last_dim // B, B) + + total_batches = act_tensor.shape[0] + err = torch.zeros( + [percentiles.shape[0], qtensor.shape[0]], + device=tensor.device, + dtype=tensor.dtype, + ) + + with torch.no_grad(): + for b in tqdm(range(0, total_batches, BATCH_SIZE), desc="Batching quant output", disable=True): + act_b = act_tensor[b:b + BATCH_SIZE] + out_orig = torch.matmul(act_b, qtensor.T) + out_q = torch.einsum('asb,phb->pash', act_b, quant_tensor.to(act_tensor.dtype)) + err += torch.norm(out_q - out_orig, p=2, dim=(1, 2)) + + del act_b, out_q, out_orig + + min_err_idx = torch.argmin(err, dim=0) + torch.cuda.empty_cache() + + quant_tensor = quant_tensor[min_err_idx, torch.arange(quant_tensor.shape[1])] + + tensor_meta = MXIntTensorMeta( + device=device, + dtype=ori_dtype, + shape=ori_shape, + block_dim=block_dim, + meta=mxint_meta, + ) + + tensor_out = permute_for_dequantize( + quant_tensor, ori_shape=tensor_meta.shape, block_dim=tensor_meta.block_dim + ) + out_dq = tensor_out.to(tensor_dtype) + + else: + device = str(tensor.device) + ori_shape = tuple(tensor.shape) + ori_dtype = str(tensor.dtype).removeprefix("torch.") + ndim = len(ori_shape) + assert block_dim < ndim and block_dim >= -ndim + + tensor_flat = flatten_for_quantize(tensor, block_dim) + scales, elements = extract_mxint_components(tensor_flat, mxint_meta, percentile=1.0) + + tensor_meta = MXIntTensorMeta( + device=device, + dtype=ori_dtype, + shape=ori_shape, + block_dim=block_dim, + meta=mxint_meta, + ) + + dequant = compose_mxint_tensor(scales, elements, mxint_meta) + out_dq = permute_for_dequantize(dequant, tensor_meta.shape, tensor_meta.block_dim) + out_dq = out_dq.to(dtype or tensor_dtype) + + return out_dq + + +# ============================================================================= +# Mase-style quantizer interface with STE +# ============================================================================= + + +class 
MXIntQuantize(torch.autograd.Function): + """Autograd function for MXINT quantization with STE gradient.""" + + @staticmethod + def forward( + ctx, + x: Tensor, + block_size: int, + element_bits: int, + block_dim: int, + scale_bits: int, + quantile_search: bool, + ) -> Tensor: + meta = MXIntMeta( + block_size=block_size, + scale_bits=scale_bits, + element_bits=element_bits, + ) + return mxint_quantizer_sim( + tensor=x, + block_dim=block_dim, + mxint_meta=meta, + quantile_search=quantile_search, + ) + + @staticmethod + def backward(ctx, grad_output): + # STE: pass gradient through unchanged + grad_input = grad_output.clone() + return grad_input, None, None, None, None, None + + +def mxint_quantizer( + x: Tensor, + block_size: int, + element_bits: int, + block_dim: int = -1, + scale_bits: int = 8, + quantile_search: bool = False, +) -> Tensor: + """ + MXINT quantizer with mase-style interface. + + Converts tensor to MXINT format with block-wise shared scale + and integer elements, then dequantizes back. 
+ + Args: + x: Input tensor to quantize + block_size: Number of elements per block (e.g., 32) + element_bits: Bits per element (e.g., 4 or 8) + block_dim: Dimension to apply block quantization (-1 for last dim) + scale_bits: Bits for shared scale (default 8) + quantile_search: Enable quantile-based clipping search + + Returns: + Quantized tensor in dequantized form + + Example: + >>> x = torch.randn(4, 32) + >>> q = mxint_quantizer(x, block_size=32, element_bits=8) + + Common formats: + - MXINT8: element_bits=8 + - MXINT4: element_bits=4 + """ + return MXIntQuantize.apply( + x, + block_size, + element_bits, + block_dim, + scale_bits, + quantile_search, + ) From f07e50cc3948dea0ede58e1641a3e65f2308e5f0 Mon Sep 17 00:00:00 2001 From: jiaeenie Date: Thu, 19 Feb 2026 12:03:36 +0000 Subject: [PATCH 2/5] Add mx quantized functional operations --- src/chop/nn/quantized/functional/__init__.py | 24 +++++ src/chop/nn/quantized/functional/kvcache.py | 43 ++++++++ src/chop/nn/quantized/functional/linear.py | 88 ++++++++++++++++ src/chop/nn/quantized/functional/matmul.py | 80 +++++++++++++++ src/chop/nn/quantized/functional/rope.py | 101 +++++++++++++++++++ src/chop/nn/quantized/functional/silu.py | 57 +++++++++++ src/chop/nn/quantized/functional/softmax.py | 57 +++++++++++ 7 files changed, 450 insertions(+) create mode 100644 src/chop/nn/quantized/functional/kvcache.py create mode 100644 src/chop/nn/quantized/functional/rope.py create mode 100644 src/chop/nn/quantized/functional/silu.py create mode 100644 src/chop/nn/quantized/functional/softmax.py diff --git a/src/chop/nn/quantized/functional/__init__.py b/src/chop/nn/quantized/functional/__init__.py index da30a9654..6b026f157 100644 --- a/src/chop/nn/quantized/functional/__init__.py +++ b/src/chop/nn/quantized/functional/__init__.py @@ -42,6 +42,8 @@ bmm_minifloat_ieee, bmm_binary, bmm_ternary, + bmm_mxfp, + bmm_mxint, matmul_block_fp, matmul_block_log, matmul_block_minifloat, @@ -51,7 +53,14 @@ matmul_minifloat_ieee, 
matmul_binary, matmul_ternary, + matmul_mxfp, + matmul_mxint, ) +from .softmax import softmax_mxfp, softmax_mxint, softmax_minifloat +from .silu import silu_mxfp, silu_mxint, silu_minifloat +from .rope import rope_mxfp, rope_mxint, rope_minifloat +from .kvcache import kv_cache_mxfp, kv_cache_mxint + from .mult import ( mult_block_fp, mult_block_log, @@ -203,6 +212,10 @@ "matmul_block_log": matmul_block_log, "matmul_binary": matmul_binary, "matmul_ternary": matmul_ternary, + "matmul_mxfp": matmul_mxfp, + "matmul_mxint": matmul_mxint, + "bmm_mxfp": bmm_mxfp, + "bmm_mxint": bmm_mxint, "relu_block_minifloat": relu_block_minifloat, "relu_integer": relu_integer, "relu_fixed": relu_integer, @@ -277,4 +290,15 @@ "linear_ternary": linearTernary, "linear_lutnet": linearLUT, "linear_logicnets": linearLogicNets, + "softmax_mxfp": softmax_mxfp, + "softmax_mxint": softmax_mxint, + "softmax_minifloat": softmax_minifloat, + "silu_mxfp": silu_mxfp, + "silu_mxint": silu_mxint, + "silu_minifloat": silu_minifloat, + "rope_mxfp": rope_mxfp, + "rope_mxint": rope_mxint, + "rope_minifloat": rope_minifloat, + "kv_cache_mxfp": kv_cache_mxfp, + "kv_cache_mxint": kv_cache_mxint, } diff --git a/src/chop/nn/quantized/functional/kvcache.py b/src/chop/nn/quantized/functional/kvcache.py new file mode 100644 index 000000000..99e7874ee --- /dev/null +++ b/src/chop/nn/quantized/functional/kvcache.py @@ -0,0 +1,43 @@ +from functools import partial + +from torch import Tensor + +from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer + + +def kv_cache_mxfp( + key_states: Tensor, + value_states: Tensor, + config: dict = None, +) -> tuple[Tensor, Tensor]: + x_block_size = config["data_in_block_size"] + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + + x_quantizer = partial( + mxfp_quantizer, + block_size=x_block_size, + element_exp_bits=x_exp_bits, + element_frac_bits=x_frac_bits, + block_dim=-1, + ) + + return x_quantizer(key_states), 
x_quantizer(value_states) + + +def kv_cache_mxint( + key_states: Tensor, + value_states: Tensor, + config: dict = None, +) -> tuple[Tensor, Tensor]: + x_block_size = config["data_in_block_size"] + x_element_bits = config["data_in_width"] + + x_quantizer = partial( + mxint_quantizer, + block_size=x_block_size, + element_bits=x_element_bits, + block_dim=-1, + ) + + return x_quantizer(key_states), x_quantizer(value_states) diff --git a/src/chop/nn/quantized/functional/linear.py b/src/chop/nn/quantized/functional/linear.py index 2c9f4caec..00b4cbc58 100644 --- a/src/chop/nn/quantized/functional/linear.py +++ b/src/chop/nn/quantized/functional/linear.py @@ -16,6 +16,8 @@ binary_quantizer, ternary_quantizer, mxint_hardware, + mxfp_quantizer, + mxint_quantizer, ) @@ -581,3 +583,89 @@ def linearMXIntHardware( if out_config is not None: out = out_quantizer(out) return out + + +def linearMXFP( + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, +): + w_block_size = config["weight_block_size"] + w_exp_bits = config["weight_exponent_width"] + w_frac_bits = config["weight_frac_width"] + + x_block_size = config["data_in_block_size"] + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + + b_block_size = config["bias_block_size"] + b_exp_bits = config["bias_exponent_width"] + b_frac_bits = config["bias_frac_width"] + + w_quantizer = partial( + mxfp_quantizer, + block_size=w_block_size, + element_exp_bits=w_exp_bits, + element_frac_bits=w_frac_bits, + block_dim=-1, + ) + x_quantizer = partial( + mxfp_quantizer, + block_size=x_block_size, + element_exp_bits=x_exp_bits, + element_frac_bits=x_frac_bits, + block_dim=-1, + ) + b_quantizer = partial( + mxfp_quantizer, + block_size=b_block_size, + element_exp_bits=b_exp_bits, + element_frac_bits=b_frac_bits, + block_dim=-1, + ) + + x = x_quantizer(x) + weight = w_quantizer(weight) + bias = b_quantizer(bias) if bias is not None else None + return F.linear(x, weight, bias) + 
+ +def linearMXInt( + x: Tensor, + weight: Tensor, + bias: Tensor = None, + config: dict = None, +): + w_block_size = config["weight_block_size"] + w_element_bits = config["weight_width"] + + x_block_size = config["data_in_block_size"] + x_element_bits = config["data_in_width"] + + b_block_size = config["bias_block_size"] + b_element_bits = config["bias_width"] + + w_quantizer = partial( + mxint_quantizer, + block_size=w_block_size, + element_bits=w_element_bits, + block_dim=-1, + ) + x_quantizer = partial( + mxint_quantizer, + block_size=x_block_size, + element_bits=x_element_bits, + block_dim=-1, + ) + b_quantizer = partial( + mxint_quantizer, + block_size=b_block_size, + element_bits=b_element_bits, + block_dim=-1, + ) + + x = x_quantizer(x) + weight = w_quantizer(weight) + bias = b_quantizer(bias) if bias is not None else None + return F.linear(x, weight, bias) diff --git a/src/chop/nn/quantized/functional/matmul.py b/src/chop/nn/quantized/functional/matmul.py index d06eb1ece..1ba1b4c16 100644 --- a/src/chop/nn/quantized/functional/matmul.py +++ b/src/chop/nn/quantized/functional/matmul.py @@ -15,6 +15,8 @@ minifloat_ieee_quantizer, binary_quantizer, ternary_quantizer, + mxfp_quantizer, + mxint_quantizer, ) # PyTorch has torch.matmul and torch.bmm for matrix multiplication @@ -430,3 +432,81 @@ def bmm_block_minifloat(x, y, config): def bmm_block_log(x, y, config): return generic_matmul_block_log(x, y, config, style="bmm") + + +def generic_matmul_mxfp(x, y, config, style="matmul"): + bypass = config.get("bypass", False) + matmul = matmul_mapping[style] + if bypass: + return matmul(x, y) + + x_block_size = config["data_in_block_size"] + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + y_block_size = config["weight_block_size"] + y_exp_bits = config["weight_exponent_width"] + y_frac_bits = config["weight_frac_width"] + + x_quantizer = partial( + mxfp_quantizer, + block_size=x_block_size, + element_exp_bits=x_exp_bits, + 
element_frac_bits=x_frac_bits, + block_dim=-1, + ) + y_quantizer = partial( + mxfp_quantizer, + block_size=y_block_size, + element_exp_bits=y_exp_bits, + element_frac_bits=y_frac_bits, + block_dim=-1, + ) + + x = x_quantizer(x) + y = y_quantizer(y) + return matmul(x, y) + + +def generic_matmul_mxint(x, y, config, style="matmul"): + bypass = config.get("bypass", False) + matmul = matmul_mapping[style] + if bypass: + return matmul(x, y) + + x_block_size = config["data_in_block_size"] + x_element_bits = config["data_in_width"] + y_block_size = config["weight_block_size"] + y_element_bits = config["weight_width"] + + x_quantizer = partial( + mxint_quantizer, + block_size=x_block_size, + element_bits=x_element_bits, + block_dim=-1, + ) + y_quantizer = partial( + mxint_quantizer, + block_size=y_block_size, + element_bits=y_element_bits, + block_dim=-1, + ) + + x = x_quantizer(x) + y = y_quantizer(y) + return matmul(x, y) + + +def matmul_mxfp(x, y, config): + return generic_matmul_mxfp(x, y, config, "matmul") + + +def matmul_mxint(x, y, config): + return generic_matmul_mxint(x, y, config, "matmul") + + +def bmm_mxfp(x, y, config): + return generic_matmul_mxfp(x, y, config, "bmm") + + +def bmm_mxint(x, y, config): + return generic_matmul_mxint(x, y, config, "bmm") diff --git a/src/chop/nn/quantized/functional/rope.py b/src/chop/nn/quantized/functional/rope.py new file mode 100644 index 000000000..9bc90a3fb --- /dev/null +++ b/src/chop/nn/quantized/functional/rope.py @@ -0,0 +1,101 @@ +from functools import partial + +import torch +from torch import Tensor + +from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer +from chop.nn.quantizers._minifloat_mx import MinifloatMeta, minifloat_quantizer_sim + + +def rotate_half(x: Tensor) -> Tensor: + """Rotate half the last dimension (for RoPE).""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rope(q, k, cos, sin, unsqueeze_dim=1): + cos = 
cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + seq_len = q.size(-2) + cos = cos[..., :seq_len, :] + sin = sin[..., :seq_len, :] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def rope_mxfp( + q: Tensor, + k: Tensor, + cos: Tensor, + sin: Tensor, + config: dict = None, + unsqueeze_dim: int = 1, +) -> tuple[Tensor, Tensor]: + x_block_size = config["data_in_block_size"] + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + + x_quantizer = partial( + mxfp_quantizer, + block_size=x_block_size, + element_exp_bits=x_exp_bits, + element_frac_bits=x_frac_bits, + block_dim=-1, + ) + + cos = x_quantizer(cos) + sin = x_quantizer(sin) + return _apply_rope(q, k, cos, sin, unsqueeze_dim) + + +def rope_mxint( + q: Tensor, + k: Tensor, + cos: Tensor, + sin: Tensor, + config: dict = None, + unsqueeze_dim: int = 1, +) -> tuple[Tensor, Tensor]: + x_block_size = config["data_in_block_size"] + x_element_bits = config["data_in_width"] + + x_quantizer = partial( + mxint_quantizer, + block_size=x_block_size, + element_bits=x_element_bits, + block_dim=-1, + ) + + cos = x_quantizer(cos) + sin = x_quantizer(sin) + return _apply_rope(q, k, cos, sin, unsqueeze_dim) + + +def rope_minifloat( + q: Tensor, + k: Tensor, + cos: Tensor, + sin: Tensor, + config: dict = None, + unsqueeze_dim: int = 1, +) -> tuple[Tensor, Tensor]: + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + + x_quantizer = partial( + minifloat_quantizer_sim, + minifloat_meta=MinifloatMeta( + exp_bits=x_exp_bits, + frac_bits=x_frac_bits, + is_finite=config.get("data_in_is_finite", True), + round_mode=config.get("data_in_round_mode", "rn"), + ), + ) + + cos = x_quantizer(cos) + sin = x_quantizer(sin) + return _apply_rope(q, k, cos, sin, unsqueeze_dim) diff --git a/src/chop/nn/quantized/functional/silu.py b/src/chop/nn/quantized/functional/silu.py 
new file mode 100644 index 000000000..ee9e60bad --- /dev/null +++ b/src/chop/nn/quantized/functional/silu.py @@ -0,0 +1,57 @@ +from functools import partial + +import torch +from torch import Tensor + +from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer +from chop.nn.quantizers._minifloat_mx import MinifloatMeta, minifloat_quantizer_sim + + +def silu_mxfp(x: Tensor, config: dict = None) -> Tensor: + x_block_size = config["data_in_block_size"] + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + + x_quantizer = partial( + mxfp_quantizer, + block_size=x_block_size, + element_exp_bits=x_exp_bits, + element_frac_bits=x_frac_bits, + block_dim=-1, + ) + + x = x_quantizer(x) + return torch.nn.functional.silu(x) + + +def silu_mxint(x: Tensor, config: dict = None) -> Tensor: + x_block_size = config["data_in_block_size"] + x_element_bits = config["data_in_width"] + + x_quantizer = partial( + mxint_quantizer, + block_size=x_block_size, + element_bits=x_element_bits, + block_dim=-1, + ) + + x = x_quantizer(x) + return torch.nn.functional.silu(x) + + +def silu_minifloat(x: Tensor, config: dict = None) -> Tensor: + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + + x_quantizer = partial( + minifloat_quantizer_sim, + minifloat_meta=MinifloatMeta( + exp_bits=x_exp_bits, + frac_bits=x_frac_bits, + is_finite=config.get("data_in_is_finite", True), + round_mode=config.get("data_in_round_mode", "rn"), + ), + ) + + x = x_quantizer(x) + return torch.nn.functional.silu(x) diff --git a/src/chop/nn/quantized/functional/softmax.py b/src/chop/nn/quantized/functional/softmax.py new file mode 100644 index 000000000..350cacb32 --- /dev/null +++ b/src/chop/nn/quantized/functional/softmax.py @@ -0,0 +1,57 @@ +from functools import partial + +import torch +from torch import Tensor + +from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer +from chop.nn.quantizers._minifloat_mx import 
MinifloatMeta, minifloat_quantizer_sim + + +def softmax_mxfp(x: Tensor, config: dict = None, dim: int = -1) -> Tensor: + x_block_size = config["data_in_block_size"] + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + + x_quantizer = partial( + mxfp_quantizer, + block_size=x_block_size, + element_exp_bits=x_exp_bits, + element_frac_bits=x_frac_bits, + block_dim=-1, + ) + + x = x_quantizer(x) + return torch.nn.functional.softmax(x.to(torch.float32), dim=dim).to(x.dtype) + + +def softmax_mxint(x: Tensor, config: dict = None, dim: int = -1) -> Tensor: + x_block_size = config["data_in_block_size"] + x_element_bits = config["data_in_width"] + + x_quantizer = partial( + mxint_quantizer, + block_size=x_block_size, + element_bits=x_element_bits, + block_dim=-1, + ) + + x = x_quantizer(x) + return torch.nn.functional.softmax(x.to(torch.float32), dim=dim).to(x.dtype) + + +def softmax_minifloat(x: Tensor, config: dict = None, dim: int = -1) -> Tensor: + x_exp_bits = config["data_in_exponent_width"] + x_frac_bits = config["data_in_frac_width"] + + x_quantizer = partial( + minifloat_quantizer_sim, + minifloat_meta=MinifloatMeta( + exp_bits=x_exp_bits, + frac_bits=x_frac_bits, + is_finite=config.get("data_in_is_finite", True), + round_mode=config.get("data_in_round_mode", "rn"), + ), + ) + + x = x_quantizer(x) + return torch.nn.functional.softmax(x.to(torch.float32), dim=dim).to(x.dtype) From 6e8f6bc58a1a243dc27a764c4aa4c61b8f7e0e33 Mon Sep 17 00:00:00 2001 From: jiaeenie Date: Sat, 21 Feb 2026 08:53:32 +0000 Subject: [PATCH 3/5] Add mx quant llama --- src/chop/nn/quantized/functional/linear.py | 8 +- src/chop/nn/quantized/modules/__init__.py | 24 +- src/chop/nn/quantized/modules/embedding.py | 61 ++++ src/chop/nn/quantized/modules/linear.py | 48 +++ .../nn/quantized/modules/llama/__init__.py | 6 +- .../nn/quantized/modules/llama/attention.py | 293 +++++++++++++++++- src/chop/nn/quantized/modules/llama/mlp.py | 37 ++- 
.../nn/quantized/modules/llama/rms_norm.py | 56 +++- .../nn/quantizers/_minifloat_mx/__init__.py | 4 +- .../nn/quantizers/_minifloat_mx/minifloat.py | 33 ++ .../passes/module/module_modify_helper.py | 32 +- 11 files changed, 563 insertions(+), 39 deletions(-) create mode 100644 src/chop/nn/quantized/modules/embedding.py create mode 100644 src/chop/nn/quantizers/_minifloat_mx/minifloat.py diff --git a/src/chop/nn/quantized/functional/linear.py b/src/chop/nn/quantized/functional/linear.py index 00b4cbc58..98d12d268 100644 --- a/src/chop/nn/quantized/functional/linear.py +++ b/src/chop/nn/quantized/functional/linear.py @@ -608,7 +608,7 @@ def linearMXFP( block_size=w_block_size, element_exp_bits=w_exp_bits, element_frac_bits=w_frac_bits, - block_dim=-1, + block_dim=1, ) x_quantizer = partial( mxfp_quantizer, @@ -622,7 +622,7 @@ def linearMXFP( block_size=b_block_size, element_exp_bits=b_exp_bits, element_frac_bits=b_frac_bits, - block_dim=-1, + block_dim=0, ) x = x_quantizer(x) @@ -650,7 +650,7 @@ def linearMXInt( mxint_quantizer, block_size=w_block_size, element_bits=w_element_bits, - block_dim=-1, + block_dim=1, ) x_quantizer = partial( mxint_quantizer, @@ -662,7 +662,7 @@ def linearMXInt( mxint_quantizer, block_size=b_block_size, element_bits=b_element_bits, - block_dim=-1, + block_dim=0, ) x = x_quantizer(x) diff --git a/src/chop/nn/quantized/modules/__init__.py b/src/chop/nn/quantized/modules/__init__.py index 1613a1c1c..36dec9a57 100644 --- a/src/chop/nn/quantized/modules/__init__.py +++ b/src/chop/nn/quantized/modules/__init__.py @@ -9,7 +9,18 @@ RobertaSelfOutputLSQInteger, ) -from .llama import LlamaAttentionLSQInteger, LlamaRMSNormLSQInteger, LlamaMLPLSQInteger +from .embedding import EmbeddingMXFP, EmbeddingMXInt + +from .llama import ( + LlamaAttentionLSQInteger, + LlamaRMSNormLSQInteger, + LlamaMLPLSQInteger, + LlamaAttentionMXFP, + LlamaAttentionMXInt, + LlamaMLPMXFP, + LlamaMLPMXInt, + LlamaRMSNormMinifloat, +) # from .add import AddInteger from .conv1d 
import ( @@ -60,6 +71,8 @@ LinearLUT, LinearLogicNets, LinearMXIntHardware, + LinearMXFP, + LinearMXInt, ) from .pool2d import ( AdaptiveAvgPool2dInteger, @@ -200,6 +213,8 @@ "linear_fixed": LinearInteger, "linear_log": LinearLog, "linear_mxint_hardware": LinearMXIntHardware, + "linear_mxfp": LinearMXFP, + "linear_mxint": LinearMXInt, "linear_block_log": LinearBlockLog, "linear_minifloat_ieee": LinearMinifloatIEEE, "linear_minifloat_denorm": LinearMinifloatDenorm, @@ -291,6 +306,8 @@ "softplus_ternary": SoftplusTernary, "batch_norm1d_fixed": BatchNorm1dInteger, "batch_norm1d_linear": BatchNorm1dInteger, + "embedding_mxfp": EmbeddingMXFP, + "embedding_mxint": EmbeddingMXInt, } quantized_bert_module_map = { @@ -309,8 +326,13 @@ quantized_llama_module_map = { "llama_self_attention_lsqinteger": LlamaAttentionLSQInteger, + "llama_self_attention_mxfp": LlamaAttentionMXFP, + "llama_self_attention_mxint": LlamaAttentionMXInt, "llama_rms_norm_lsqinteger": LlamaRMSNormLSQInteger, + "llama_rms_norm_minifloat": LlamaRMSNormMinifloat, "llama_mlp_lsqinteger": LlamaMLPLSQInteger, + "llama_mlp_mxfp": LlamaMLPMXFP, + "llama_mlp_mxint": LlamaMLPMXInt, } quantized_module_map = ( diff --git a/src/chop/nn/quantized/modules/embedding.py b/src/chop/nn/quantized/modules/embedding.py new file mode 100644 index 000000000..a625020aa --- /dev/null +++ b/src/chop/nn/quantized/modules/embedding.py @@ -0,0 +1,61 @@ +from functools import partial + +import torch +from torch import Tensor, nn + +from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer + + +class EmbeddingMXFP(nn.Embedding): + """MXFP-quantized Embedding. 
Weight is quantized at forward time.""" + + def __init__(self, num_embeddings, embedding_dim, config=None, **kwargs): + super().__init__(num_embeddings, embedding_dim, **kwargs) + self.config = config or {} + self.bypass = self.config.get("bypass", False) + + if not self.bypass: + self.w_quantizer = partial( + mxfp_quantizer, + block_size=self.config["weight_block_size"], + element_exp_bits=self.config["weight_exponent_width"], + element_frac_bits=self.config["weight_frac_width"], + block_dim=1, + ) + + @torch.no_grad() + def forward(self, input: Tensor) -> Tensor: + if self.bypass: + return super().forward(input) + weight = self.w_quantizer(self.weight) + return torch.nn.functional.embedding( + input, weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse, + ) + + +class EmbeddingMXInt(nn.Embedding): + """MXInt-quantized Embedding. Weight is quantized at forward time.""" + + def __init__(self, num_embeddings, embedding_dim, config=None, **kwargs): + super().__init__(num_embeddings, embedding_dim, **kwargs) + self.config = config or {} + self.bypass = self.config.get("bypass", False) + + if not self.bypass: + self.w_quantizer = partial( + mxint_quantizer, + block_size=self.config["weight_block_size"], + element_bits=self.config["weight_width"], + block_dim=1, + ) + + @torch.no_grad() + def forward(self, input: Tensor) -> Tensor: + if self.bypass: + return super().forward(input) + weight = self.w_quantizer(self.weight) + return torch.nn.functional.embedding( + input, weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse, + ) diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py index 5d8d389a5..d1d01e651 100644 --- a/src/chop/nn/quantized/modules/linear.py +++ b/src/chop/nn/quantized/modules/linear.py @@ -12,6 +12,8 @@ linearMinifloatDenorm, linearMinifloatIEEE, linearTernary, + linearMXFP, + linearMXInt, ) import torch from torch import Tensor 
@@ -810,3 +812,49 @@ def forward(self, x): return linearMXIntHardware( x, self.weight, self.bias, self.config, self.out_config ) + + +class LinearMXFP(_LinearBase): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + config=None, + ) -> None: + super().__init__(in_features, out_features, bias, device, dtype) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + def forward(self, x): + if self.bypass: + return F.linear(x, self.weight, self.bias) + return linearMXFP(x, self.weight, self.bias, self.config) + + +class LinearMXInt(_LinearBase): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + config=None, + ) -> None: + super().__init__(in_features, out_features, bias, device, dtype) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + def forward(self, x): + if self.bypass: + return F.linear(x, self.weight, self.bias) + return linearMXInt(x, self.weight, self.bias, self.config) diff --git a/src/chop/nn/quantized/modules/llama/__init__.py b/src/chop/nn/quantized/modules/llama/__init__.py index 289cb954c..e0b34fc23 100644 --- a/src/chop/nn/quantized/modules/llama/__init__.py +++ b/src/chop/nn/quantized/modules/llama/__init__.py @@ -1,3 +1,3 @@ -from .attention import LlamaAttentionLSQInteger -from .rms_norm import LlamaRMSNormLSQInteger -from .mlp import LlamaMLPLSQInteger +from .attention import LlamaAttentionLSQInteger, LlamaAttentionMXFP, LlamaAttentionMXInt +from .rms_norm import LlamaRMSNormLSQInteger, LlamaRMSNormMinifloat +from .mlp import LlamaMLPLSQInteger, LlamaMLPMXFP, LlamaMLPMXInt diff --git a/src/chop/nn/quantized/modules/llama/attention.py b/src/chop/nn/quantized/modules/llama/attention.py index 2b0a252ae..ef7666dd1 100644 --- 
a/src/chop/nn/quantized/modules/llama/attention.py +++ b/src/chop/nn/quantized/modules/llama/attention.py @@ -1,27 +1,24 @@ -from typing import Callable, List, Optional, Tuple, Union +from typing import Optional, Tuple -from chop.nn.quantizers.SNN.LSQ import LSQInteger import torch -from torch import dropout, nn - -import math - +from torch import Tensor, nn, LongTensor from transformers.models.llama.modeling_llama import ( - LlamaRMSNorm, - LlamaRotaryEmbedding, apply_rotary_pos_emb, - ACT2FN, LlamaConfig, Cache, repeat_kv, - LlamaForCausalLM, - LlamaDecoderLayer, - ALL_ATTENTION_FUNCTIONS, - eager_attention_forward, + LlamaAttention, ) +from functools import partial + +from chop.nn.quantizers.SNN.LSQ import LSQInteger +from chop.nn.quantizers import mxfp_quantizer, mxint_quantizer +from chop.nn.quantized.functional.rope import rope_minifloat +from chop.nn.quantized.functional.softmax import softmax_minifloat +from chop.nn.quantized.functional.kvcache import kv_cache_mxfp, kv_cache_mxint + import logging -from torch.utils.tensorboard import SummaryWriter logger = logging.getLogger(__name__) @@ -128,3 +125,271 @@ def forward( attn_output = self.o_quant(attn_output) return attn_output, attn_weights + + +class LlamaAttentionMXFP(LlamaAttention): + """MXFP-quantized LlamaAttention. 
+ """ + + def __init__(self, config, layer_idx, q_config: dict = None): + super().__init__(config, layer_idx) + q_config = q_config or {} + self.qk_config = q_config.get("qk_matmul", {}) + self.av_config = q_config.get("av_matmul", {}) + self.rope_config = q_config.get("rope", {}) + self.softmax_config = q_config.get("softmax", {}) + self.kv_cache_config = q_config.get("kv_cache", {}) + self.qk_bypass = self.qk_config.get("bypass", False) + self.av_bypass = self.av_config.get("bypass", False) + self.rope_bypass = self.rope_config.get("bypass", False) + self.softmax_bypass = self.softmax_config.get("bypass", False) + self.kv_cache_bypass = self.kv_cache_config.get("bypass", False) + + def forward( + self, + hidden_states: Tensor, + position_embeddings: Tuple[Tensor, Tensor], + attention_mask: Optional[Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[LongTensor] = None, + **kwargs, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + if not self.rope_bypass: + query_states, key_states = rope_minifloat( + query_states, key_states, cos, sin, self.rope_config, + ) + else: + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + if not self.kv_cache_bypass: + key_states, value_states = kv_cache_mxfp( + key_states, value_states, self.kv_cache_config, + ) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs, + ) + + attn_output, attn_weights = 
_eager_attention_forward_mxfp( + self, query_states, key_states, value_states, attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + qk_bypass=self.qk_bypass, + qk_config=self.qk_config, + av_bypass=self.av_bypass, + av_config=self.av_config, + softmax_bypass=self.softmax_bypass, + softmax_config=self.softmax_config, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + @classmethod + def from_attention(cls, attention: LlamaAttention, q_config: dict = None): + new_attn = cls( + config=attention.config, + layer_idx=attention.layer_idx, + q_config=q_config, + ) + device, dtype = next(attention.parameters()).device, next(attention.parameters()).dtype + new_attn = new_attn.to(dtype=dtype, device=device) + new_attn.load_state_dict(attention.state_dict(), strict=True) + return new_attn + + +class LlamaAttentionMXInt(LlamaAttention): + """MXInt-quantized LlamaAttention. 
+ """ + + def __init__(self, config, layer_idx, q_config: dict = None): + super().__init__(config, layer_idx) + q_config = q_config or {} + self.qk_config = q_config.get("qk_matmul", {}) + self.av_config = q_config.get("av_matmul", {}) + self.rope_config = q_config.get("rope", {}) + self.softmax_config = q_config.get("softmax", {}) + self.kv_cache_config = q_config.get("kv_cache", {}) + self.qk_bypass = self.qk_config.get("bypass", False) + self.av_bypass = self.av_config.get("bypass", False) + self.rope_bypass = self.rope_config.get("bypass", False) + self.softmax_bypass = self.softmax_config.get("bypass", False) + self.kv_cache_bypass = self.kv_cache_config.get("bypass", False) + + def forward( + self, + hidden_states: Tensor, + position_embeddings: Tuple[Tensor, Tensor], + attention_mask: Optional[Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[LongTensor] = None, + **kwargs, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + if not self.rope_bypass: + query_states, key_states = rope_minifloat( + query_states, key_states, cos, sin, self.rope_config, + ) + else: + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + if not self.kv_cache_bypass: + key_states, value_states = kv_cache_mxint( + key_states, value_states, self.kv_cache_config, + ) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs, + ) + + attn_output, attn_weights = 
_eager_attention_forward_mxint( + self, query_states, key_states, value_states, attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + qk_bypass=self.qk_bypass, + qk_config=self.qk_config, + av_bypass=self.av_bypass, + av_config=self.av_config, + softmax_bypass=self.softmax_bypass, + softmax_config=self.softmax_config, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + @classmethod + def from_attention(cls, attention: LlamaAttention, q_config: dict = None): + new_attn = cls( + config=attention.config, + layer_idx=attention.layer_idx, + q_config=q_config, + ) + device, dtype = next(attention.parameters()).device, next(attention.parameters()).dtype + new_attn = new_attn.to(dtype=dtype, device=device) + new_attn.load_state_dict(attention.state_dict(), strict=True) + return new_attn + + +def _eager_attention_forward_mxfp( + module, query, key, value, attention_mask, scaling, + dropout=0.0, qk_bypass=False, qk_config=None, + av_bypass=False, av_config=None, + softmax_bypass=False, softmax_config=None, **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + if not qk_bypass: + q_quantizer = partial( + mxfp_quantizer, + block_size=qk_config["data_in_block_size"], + element_exp_bits=qk_config["data_in_exponent_width"], + element_frac_bits=qk_config["data_in_frac_width"], + block_dim=-1, + ) + query = q_quantizer(query) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if not softmax_bypass: + attn_weights = softmax_minifloat(attn_weights, softmax_config, dim=-1) + else: + attn_weights = nn.functional.softmax( + attn_weights.to(torch.float32), dim=-1, + 
).to(attn_weights.dtype) + + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training, + ) + + if not av_bypass: + a_quantizer = partial( + mxfp_quantizer, + block_size=av_config["data_in_block_size"], + element_exp_bits=av_config["data_in_exponent_width"], + element_frac_bits=av_config["data_in_frac_width"], + block_dim=-1, + ) + attn_weights = a_quantizer(attn_weights) + + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def _eager_attention_forward_mxint( + module, query, key, value, attention_mask, scaling, + dropout=0.0, qk_bypass=False, qk_config=None, + av_bypass=False, av_config=None, + softmax_bypass=False, softmax_config=None, **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + if not qk_bypass: + q_quantizer = partial( + mxint_quantizer, + block_size=qk_config["data_in_block_size"], + element_bits=qk_config["data_in_width"], + block_dim=-1, + ) + query = q_quantizer(query) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if not softmax_bypass: + attn_weights = softmax_minifloat(attn_weights, softmax_config, dim=-1) + else: + attn_weights = nn.functional.softmax( + attn_weights.to(torch.float32), dim=-1, + ).to(attn_weights.dtype) + + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training, + ) + + if not av_bypass: + a_quantizer = partial( + mxint_quantizer, + block_size=av_config["data_in_block_size"], + element_bits=av_config["data_in_width"], + block_dim=-1, + ) + attn_weights = a_quantizer(attn_weights) + + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return 
attn_output, attn_weights diff --git a/src/chop/nn/quantized/modules/llama/mlp.py b/src/chop/nn/quantized/modules/llama/mlp.py index 94a24bbbf..533504da5 100644 --- a/src/chop/nn/quantized/modules/llama/mlp.py +++ b/src/chop/nn/quantized/modules/llama/mlp.py @@ -1,9 +1,8 @@ import torch -from torch import nn -import math +from torch import nn, Tensor from chop.nn.quantizers.SNN.LSQ import LSQInteger -from typing import Optional, Tuple +from chop.nn.quantized.functional.silu import silu_minifloat from transformers.models.llama.modeling_llama import LlamaMLP, ACT2FN @@ -44,3 +43,35 @@ def forward(self, x): down_proj = self.down_dense_quan(down_proj) return down_proj + + +class LlamaMLPMXFP(LlamaMLP): + """MXFP-quantized LlamaMLP. SiLU uses minifloat quantization.""" + + def __init__(self, config, layer_idx=None, q_config: dict = None): + super().__init__(config) + self.layer_idx = layer_idx + self.q_config = q_config or {} + self.bypass = self.q_config.get("bypass", False) + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + return super().forward(x) + x = silu_minifloat(self.gate_proj(x), self.q_config) * self.up_proj(x) + return self.down_proj(x) + + +class LlamaMLPMXInt(LlamaMLP): + """MXInt-quantized LlamaMLP. 
SiLU uses minifloat quantization.""" + + def __init__(self, config, layer_idx=None, q_config: dict = None): + super().__init__(config) + self.layer_idx = layer_idx + self.q_config = q_config or {} + self.bypass = self.q_config.get("bypass", False) + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + return super().forward(x) + x = silu_minifloat(self.gate_proj(x), self.q_config) * self.up_proj(x) + return self.down_proj(x) diff --git a/src/chop/nn/quantized/modules/llama/rms_norm.py b/src/chop/nn/quantized/modules/llama/rms_norm.py index 1095b2882..c459a26d3 100644 --- a/src/chop/nn/quantized/modules/llama/rms_norm.py +++ b/src/chop/nn/quantized/modules/llama/rms_norm.py @@ -1,9 +1,10 @@ +from functools import partial + import torch -from torch import nn -import math +from torch import Tensor, nn from chop.nn.quantizers.SNN.LSQ import LSQInteger -from typing import Optional, Tuple +from chop.nn.quantizers._minifloat_mx import MinifloatMeta, minifloat_quantizer_sim from transformers.models.llama.modeling_llama import LlamaRMSNorm @@ -28,3 +29,52 @@ def forward(self, hidden_states): def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class LlamaRMSNormMinifloat(LlamaRMSNorm): + """Minifloat-quantized LlamaRMSNorm. 
Weight and input use minifloat quantization at forward time.""" + + def __init__(self, config=None, layer_idx=None, q_config: dict = None): + super().__init__(hidden_size=config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + self.q_config = q_config or {} + self.variance_epsilon = config.rms_norm_eps + self.bypass = self.q_config.get("bypass", False) + self.weight_bypass = self.q_config.get("weight_bypass", False) + self.data_in_bypass = self.q_config.get("data_in_bypass", False) + + if not self.bypass and not self.weight_bypass: + self.w_quantizer = partial( + minifloat_quantizer_sim, + minifloat_meta=MinifloatMeta( + exp_bits=self.q_config["weight_exponent_width"], + frac_bits=self.q_config["weight_frac_width"], + is_finite=self.q_config.get("weight_is_finite", True), + round_mode=self.q_config.get("weight_round_mode", "rn"), + ), + ) + else: + self.w_quantizer = None + + if not self.bypass and not self.data_in_bypass: + self.x_quantizer = partial( + minifloat_quantizer_sim, + minifloat_meta=MinifloatMeta( + exp_bits=self.q_config["data_in_exponent_width"], + frac_bits=self.q_config["data_in_frac_width"], + is_finite=self.q_config.get("data_in_is_finite", True), + round_mode=self.q_config.get("data_in_round_mode", "rn"), + ), + ) + else: + self.x_quantizer = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + if self.x_quantizer is not None: + hidden_states = self.x_quantizer(hidden_states) + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + weight = self.w_quantizer(self.weight) if self.w_quantizer is not None else self.weight + return weight * hidden_states.to(input_dtype) diff --git a/src/chop/nn/quantizers/_minifloat_mx/__init__.py b/src/chop/nn/quantizers/_minifloat_mx/__init__.py index 053f7f8f0..b25cbfae2 100644 --- a/src/chop/nn/quantizers/_minifloat_mx/__init__.py +++ 
b/src/chop/nn/quantizers/_minifloat_mx/__init__.py @@ -1,15 +1,17 @@ """ Internal minifloat module for MX-format quantizers. -This is used internally by MXFP. +This is used internally by MXFP and by quantized functions (softmax, silu, rope). """ from .meta import MinifloatMeta, MinifloatTensorMeta from .fake import extract_minifloat_component, compose_minifloat_component +from .minifloat import minifloat_quantizer_sim __all__ = [ "MinifloatMeta", "MinifloatTensorMeta", "extract_minifloat_component", "compose_minifloat_component", + "minifloat_quantizer_sim", ] diff --git a/src/chop/nn/quantizers/_minifloat_mx/minifloat.py b/src/chop/nn/quantizers/_minifloat_mx/minifloat.py new file mode 100644 index 000000000..c722ffc79 --- /dev/null +++ b/src/chop/nn/quantizers/_minifloat_mx/minifloat.py @@ -0,0 +1,33 @@ +""" +Minifloat quantize-dequantize simulation. +""" + +import torch +from torch import Tensor + +from .meta import MinifloatMeta, MinifloatTensorMeta +from .fake import extract_minifloat_component, compose_minifloat_component + + +def minifloat_quantizer_sim( + tensor: Tensor, + minifloat_meta: MinifloatMeta, + output_dtype: torch.dtype | None = None, +) -> Tensor: + """ + Quantize and dequantize a tensor using minifloat format. 
+ + Args: + tensor: Input tensor to quantize + minifloat_meta: Minifloat format specification + output_dtype: Desired output dtype (default: same as input) + + Returns: + Dequantized tensor + """ + ori_dtype = tensor.dtype + element = extract_minifloat_component(tensor, minifloat_meta) + + return compose_minifloat_component( + element, minifloat_meta, output_dtype=output_dtype or ori_dtype + ) diff --git a/src/chop/passes/module/module_modify_helper.py b/src/chop/passes/module/module_modify_helper.py index b8bfe6864..7a3cf7aa7 100644 --- a/src/chop/passes/module/module_modify_helper.py +++ b/src/chop/passes/module/module_modify_helper.py @@ -176,16 +176,28 @@ def instantiate_conv2d(module, postfix, module_map, additional_module_args): def instantiate_embedding(module, postfix, module_map, additional_module_args): embedding_cls = module_map[f"embedding_{postfix}"] - embedding = embedding_cls( - num_embeddings=module.num_embeddings, - embedding_dim=module.embedding_dim, - padding_idx=module.padding_idx, - max_norm=module.max_norm, - norm_type=module.norm_type, - scale_grad_by_freq=module.scale_grad_by_freq, - sparse=module.sparse, - **additional_module_args, - ) + if "config" in inspect.signature(embedding_cls.__init__).parameters: + embedding = embedding_cls( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse, + config=additional_module_args, + ) + else: + embedding = embedding_cls( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + padding_idx=module.padding_idx, + max_norm=module.max_norm, + norm_type=module.norm_type, + scale_grad_by_freq=module.scale_grad_by_freq, + sparse=module.sparse, + **additional_module_args, + ) return embedding From 3e64a124e7177979e4e9b0d2138622686de5c31e Mon Sep 17 00:00:00 2001 From: jiaeenie Date: Mon, 23 Feb 2026 
11:35:15 +0000 Subject: [PATCH 4/5] Refactor linear to quantize in load_state_dict for PTQ --- src/chop/nn/quantized/functional/linear.py | 85 ---------------- src/chop/nn/quantized/modules/linear.py | 113 +++++++++++++++++++-- 2 files changed, 105 insertions(+), 93 deletions(-) diff --git a/src/chop/nn/quantized/functional/linear.py b/src/chop/nn/quantized/functional/linear.py index 98d12d268..a8cf68298 100644 --- a/src/chop/nn/quantized/functional/linear.py +++ b/src/chop/nn/quantized/functional/linear.py @@ -584,88 +584,3 @@ def linearMXIntHardware( out = out_quantizer(out) return out - -def linearMXFP( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, -): - w_block_size = config["weight_block_size"] - w_exp_bits = config["weight_exponent_width"] - w_frac_bits = config["weight_frac_width"] - - x_block_size = config["data_in_block_size"] - x_exp_bits = config["data_in_exponent_width"] - x_frac_bits = config["data_in_frac_width"] - - b_block_size = config["bias_block_size"] - b_exp_bits = config["bias_exponent_width"] - b_frac_bits = config["bias_frac_width"] - - w_quantizer = partial( - mxfp_quantizer, - block_size=w_block_size, - element_exp_bits=w_exp_bits, - element_frac_bits=w_frac_bits, - block_dim=1, - ) - x_quantizer = partial( - mxfp_quantizer, - block_size=x_block_size, - element_exp_bits=x_exp_bits, - element_frac_bits=x_frac_bits, - block_dim=-1, - ) - b_quantizer = partial( - mxfp_quantizer, - block_size=b_block_size, - element_exp_bits=b_exp_bits, - element_frac_bits=b_frac_bits, - block_dim=0, - ) - - x = x_quantizer(x) - weight = w_quantizer(weight) - bias = b_quantizer(bias) if bias is not None else None - return F.linear(x, weight, bias) - - -def linearMXInt( - x: Tensor, - weight: Tensor, - bias: Tensor = None, - config: dict = None, -): - w_block_size = config["weight_block_size"] - w_element_bits = config["weight_width"] - - x_block_size = config["data_in_block_size"] - x_element_bits = config["data_in_width"] - - 
b_block_size = config["bias_block_size"] - b_element_bits = config["bias_width"] - - w_quantizer = partial( - mxint_quantizer, - block_size=w_block_size, - element_bits=w_element_bits, - block_dim=1, - ) - x_quantizer = partial( - mxint_quantizer, - block_size=x_block_size, - element_bits=x_element_bits, - block_dim=-1, - ) - b_quantizer = partial( - mxint_quantizer, - block_size=b_block_size, - element_bits=b_element_bits, - block_dim=0, - ) - - x = x_quantizer(x) - weight = w_quantizer(weight) - bias = b_quantizer(bias) if bias is not None else None - return F.linear(x, weight, bias) diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py index d1d01e651..48761ed8b 100644 --- a/src/chop/nn/quantized/modules/linear.py +++ b/src/chop/nn/quantized/modules/linear.py @@ -12,8 +12,6 @@ linearMinifloatDenorm, linearMinifloatIEEE, linearTernary, - linearMXFP, - linearMXInt, ) import torch from torch import Tensor @@ -35,6 +33,8 @@ binary_quantizer, ternary_quantizer, mxint_hardware, + mxint_quantizer, + mxfp_quantizer, ) # LUTNet @@ -815,6 +815,7 @@ def forward(self, x): class LinearMXFP(_LinearBase): + # NOTE: backward is not supported — inference only (PTQ) def __init__( self, in_features: int, @@ -828,16 +829,68 @@ def __init__( assert config is not None, "config is None!" 
self.config = config self.bypass = config.get("bypass", False) - if self.bypass: - return + self.gptq = config.get("gptq", False) + self.clip_search = config.get("clip_search", False) + + def load_state_dict(self, state_dict, strict=True, assign=False): + """Load pretrained weights, then quantize them in place.""" + result = super().load_state_dict(state_dict, strict=strict, assign=assign) + + if self.bypass or self.gptq: + return result + + # Quantize weight + w_block_size = self.config["weight_block_size"] + w_exp_bits = self.config["weight_exponent_width"] + w_frac_bits = self.config["weight_frac_width"] + self.weight.data.copy_( + mxfp_quantizer( + self.weight.data, + block_size=w_block_size, + element_exp_bits=w_exp_bits, + element_frac_bits=w_frac_bits, + block_dim=1, + ) + ) + + # Quantize bias + b_block_size = self.config.get("bias_block_size") + b_exp_bits = self.config.get("bias_exponent_width") + b_frac_bits = self.config.get("bias_frac_width") + if self.bias is not None and b_block_size is not None and b_exp_bits is not None: + self.bias.data.copy_( + mxfp_quantizer( + self.bias.data, + block_size=b_block_size, + element_exp_bits=b_exp_bits, + element_frac_bits=b_frac_bits, + block_dim=0, + ) + ) + + return result + @torch.no_grad() def forward(self, x): if self.bypass: return F.linear(x, self.weight, self.bias) - return linearMXFP(x, self.weight, self.bias, self.config) + + # Only quantize activations; weights/bias already quantized in load_state_dict + x_block_size = self.config.get("data_in_block_size") + x_exp_bits = self.config.get("data_in_exponent_width") + x_frac_bits = self.config.get("data_in_frac_width") + if x_block_size is not None and x_exp_bits is not None: + x = mxfp_quantizer( + x, block_size=x_block_size, + element_exp_bits=x_exp_bits, element_frac_bits=x_frac_bits, + block_dim=-1, + ) + + return F.linear(x, self.weight, self.bias) class LinearMXInt(_LinearBase): + # NOTE: backward is not supported — inference only (PTQ) def __init__( 
self, in_features: int, @@ -851,10 +904,54 @@ def __init__( assert config is not None, "config is None!" self.config = config self.bypass = config.get("bypass", False) - if self.bypass: - return + self.gptq = config.get("gptq", False) + self.clip_search = config.get("clip_search", False) + + def load_state_dict(self, state_dict, strict=True, assign=False): + """Load pretrained weights, then quantize them in place.""" + result = super().load_state_dict(state_dict, strict=strict, assign=assign) + + if self.bypass or self.gptq: + return result + + # Quantize weight + w_block_size = self.config["weight_block_size"] + w_element_bits = self.config["weight_width"] + self.weight.data.copy_( + mxint_quantizer( + self.weight.data, + block_size=w_block_size, + element_bits=w_element_bits, + block_dim=1, + quantile_search=self.clip_search, + ) + ) + # Quantize bias + b_block_size = self.config.get("bias_block_size") + b_element_bits = self.config.get("bias_width") + if self.bias is not None and b_block_size is not None and b_element_bits is not None: + self.bias.data.copy_( + mxint_quantizer( + self.bias.data, + block_size=b_block_size, + element_bits=b_element_bits, + block_dim=0, + quantile_search=self.clip_search, + ) + ) + + return result + + @torch.no_grad() def forward(self, x): if self.bypass: return F.linear(x, self.weight, self.bias) - return linearMXInt(x, self.weight, self.bias, self.config) + + # Only quantize activations; weights/bias already quantized in load_state_dict + x_block_size = self.config.get("data_in_block_size") + x_element_bits = self.config.get("data_in_width") + if x_block_size is not None and x_element_bits is not None: + x = mxint_quantizer(x, block_size=x_block_size, element_bits=x_element_bits, block_dim=-1) + + return F.linear(x, self.weight, self.bias) From 6f4cbf197d0a51c51f3c7bb2154cd2dee8cbca46 Mon Sep 17 00:00:00 2001 From: jiaeenie Date: Mon, 23 Feb 2026 12:01:56 +0000 Subject: [PATCH 5/5] Support GPTQ pre-pass to 
quantize_module_transform_pass --- .../passes/module/transforms/gptq/__init__.py | 3 + .../module/transforms/gptq/checkpoint.py | 122 ++++++++++ .../module/transforms/gptq/data_utils.py | 102 ++++++++ .../passes/module/transforms/gptq/gptq.py | 115 +++++++++ .../transforms/gptq/quantize_dispatch.py | 82 +++++++ src/chop/passes/module/transforms/gptq/run.py | 220 ++++++++++++++++++ .../passes/module/transforms/gptq/utils.py | 41 ++++ .../module/transforms/quantize/quantize.py | 7 + 8 files changed, 692 insertions(+) create mode 100644 src/chop/passes/module/transforms/gptq/__init__.py create mode 100644 src/chop/passes/module/transforms/gptq/checkpoint.py create mode 100644 src/chop/passes/module/transforms/gptq/data_utils.py create mode 100644 src/chop/passes/module/transforms/gptq/gptq.py create mode 100644 src/chop/passes/module/transforms/gptq/quantize_dispatch.py create mode 100644 src/chop/passes/module/transforms/gptq/run.py create mode 100644 src/chop/passes/module/transforms/gptq/utils.py diff --git a/src/chop/passes/module/transforms/gptq/__init__.py b/src/chop/passes/module/transforms/gptq/__init__.py new file mode 100644 index 000000000..5c40ffc1f --- /dev/null +++ b/src/chop/passes/module/transforms/gptq/__init__.py @@ -0,0 +1,3 @@ +from .run import run_gptq + +__all__ = ["run_gptq"] diff --git a/src/chop/passes/module/transforms/gptq/checkpoint.py b/src/chop/passes/module/transforms/gptq/checkpoint.py new file mode 100644 index 000000000..d37399a48 --- /dev/null +++ b/src/chop/passes/module/transforms/gptq/checkpoint.py @@ -0,0 +1,122 @@ +import json +import logging +import os +from pathlib import Path + +from safetensors.torch import save_file, load_file + + +def save_layer_checkpoint(model, layer_idx, checkpoint_dir, model_name="quantized_model"): + if checkpoint_dir is None: + return + + checkpoint_path = Path(checkpoint_dir) + checkpoint_path.mkdir(parents=True, exist_ok=True) + + layer_checkpoint_file = checkpoint_path / 
f"{model_name}_layer_{layer_idx}.safetensors" + + if layer_checkpoint_file.exists() and layer_checkpoint_file.is_dir(): + import shutil + shutil.rmtree(layer_checkpoint_file) + logging.info(f"Removed existing directory: {layer_checkpoint_file}") + + logging.info(f"Saving layer {layer_idx} checkpoint to {layer_checkpoint_file}") + + try: + layer_state_dict = {} + layer_prefix = f"model.layers.{layer_idx}." + + for name, param in model.named_parameters(): + if name.startswith(layer_prefix): + relative_name = name[len(layer_prefix):] + layer_state_dict[relative_name] = param.detach().cpu() + + if not layer_state_dict: + logging.warning(f"No parameters found for layer {layer_idx} with prefix {layer_prefix}") + return + + save_file(layer_state_dict, str(layer_checkpoint_file)) + + metadata = { + "layer_idx": layer_idx, + "total_layers": len(model.model.layers), + "checkpoint_file": str(layer_checkpoint_file), + "model_name": model_name, + "num_parameters": len(layer_state_dict), + "parameter_names": list(layer_state_dict.keys()) + } + + metadata_file = checkpoint_path / f"{model_name}_layer_{layer_idx}_metadata.json" + with open(metadata_file, 'w') as f: + json.dump(metadata, f, indent=2) + + logging.info(f"Layer {layer_idx} checkpoint saved successfully ({len(layer_state_dict)} parameters)") + + except Exception as e: + logging.error(f"Failed to save layer {layer_idx} checkpoint: {e}") + + +def detect_quantized_layers(checkpoint_dir, model_name="quantized_model"): + if checkpoint_dir is None or not os.path.exists(checkpoint_dir): + return {} + + checkpoint_path = Path(checkpoint_dir) + checkpoints = list(checkpoint_path.glob(f"{model_name}_layer_*.safetensors")) + + quantized_layers = {} + for checkpoint in checkpoints: + try: + layer_idx = int(checkpoint.stem.split('_layer_')[-1]) + quantized_layers[layer_idx] = str(checkpoint) + except ValueError: + continue + + return quantized_layers + + +def load_layer_checkpoint(model, layer_idx, checkpoint_file): + if not 
os.path.exists(checkpoint_file): + logging.error(f"Layer checkpoint file {checkpoint_file} not found") + return False + + try: + layer_state_dict = load_file(checkpoint_file) + + layer_prefix = f"model.layers.{layer_idx}." + model_state_dict = {} + + for param_name, param_value in layer_state_dict.items(): + full_param_name = layer_prefix + param_name + model_state_dict[full_param_name] = param_value + + model.load_state_dict(model_state_dict, strict=False) + + logging.info(f"Successfully loaded layer {layer_idx} from checkpoint") + return True + + except Exception as e: + logging.error(f"Failed to load layer {layer_idx} checkpoint: {e}") + return False + + +def auto_load_quantized_layers(model, checkpoint_dir, model_name="quantized_model"): + quantized_layers = detect_quantized_layers(checkpoint_dir, model_name) + + if not quantized_layers: + logging.info("No quantized layer checkpoints found") + return -1 + + loaded_count = 0 + max_layer_idx = -1 + + for layer_idx in sorted(quantized_layers.keys()): + checkpoint_file = quantized_layers[layer_idx] + if load_layer_checkpoint(model, layer_idx, checkpoint_file): + loaded_count += 1 + max_layer_idx = layer_idx + else: + logging.warning(f"Failed to load layer {layer_idx}, stopping auto-load") + break + + logging.info(f"Auto-loaded {loaded_count} quantized layers (up to layer {max_layer_idx})") + return max_layer_idx diff --git a/src/chop/passes/module/transforms/gptq/data_utils.py b/src/chop/passes/module/transforms/gptq/data_utils.py new file mode 100644 index 000000000..04912798c --- /dev/null +++ b/src/chop/passes/module/transforms/gptq/data_utils.py @@ -0,0 +1,102 @@ +import datasets +import random +import transformers + + +def get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode=False): + if hf_token is None: + tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) + else: + tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) + + 
if eval_mode: + testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + return testenc + else: + traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader + + +def get_c4_new(nsamples, seed, seqlen, model, hf_token=None, eval_mode=False): + if hf_token is None: + tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) + else: + tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) + + if eval_mode: + valdata = datasets.load_dataset( + 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + valenc = TokenizerWrapper(valenc) + return valenc + else: + traindata = datasets.load_dataset( + 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader + + +def get_ptb_new(nsamples, seed, seqlen, 
model, hf_token, eval_mode=False): + if hf_token is None: + tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False) + else: + tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token) + + if eval_mode: + testdata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='test') + testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') + return testenc + else: + traindata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='train') + trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader + + +def get_loaders( + name, nsamples=128, seed=0, seqlen=2048, model='', hf_token=None, eval_mode=False +): + if 'wikitext2' in name: + return get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode) + if 'ptb' in name: + return get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode) + if 'c4' in name: + return get_c4_new(nsamples, seed, seqlen, model, hf_token, eval_mode) diff --git a/src/chop/passes/module/transforms/gptq/gptq.py b/src/chop/passes/module/transforms/gptq/gptq.py new file mode 100644 index 000000000..9419f3ef3 --- /dev/null +++ b/src/chop/passes/module/transforms/gptq/gptq.py @@ -0,0 +1,115 @@ +""" +GPTQ quantization algorithm. + +Ported from Coprocessor_for_Llama/acc_simulator/gptq/gptq.py, +adapted to use Mase config dicts instead of Meta classes. 
+""" + +import math + +import torch +import tqdm + +from .quantize_dispatch import quantize_tensor +from .utils import cleanup_memory + + +class GPTQ: + + def __init__(self, layer): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + + def add_batch(self, inp, out): + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = math.sqrt(2 / self.nsamples) * inp.float() + self.H += inp.matmul(inp.t()) + + def fasterquant( + self, activation, fmt, weight_config, percdamp=.01, + cali_batch_size=32, layer_name=None, quant_search=True, + ): + """ + Run GPTQ block-wise quantization. + + Args: + activation: Activation tensor for calibration (or None). + fmt: "mxfp" or "mxint". + weight_config: Mase-style config dict with weight_block_size, etc. + percdamp: Dampening percentage for Hessian diagonal. + cali_batch_size: Batch size for quantile search calibration. + layer_name: Name for progress bar. + quant_search: Enable quantile search. 
+ """ + W = self.layer.weight.data.clone() + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + blocksize = weight_config["weight_block_size"] + for i1 in tqdm.tqdm( + range(0, self.columns, blocksize), + desc=f"Quantizing blocks {layer_name}", + disable=False, + ): + i2 = min(i1 + blocksize, self.columns) + + W1 = W[:, i1:i2].clone() + + if activation is not None: + Act1 = activation[:, :, i1:i2].clone() + Q1 = quantize_tensor( + W1, block_dim=1, fmt=fmt, config=weight_config, + quantile_search=quant_search, act_tensor=Act1, + cali_batch_size=cali_batch_size, + ) + else: + Q1 = quantize_tensor( + W1, block_dim=1, fmt=fmt, config=weight_config, + quantile_search=quant_search, + ) + + Hinv1 = Hinv[i1:i2, i1:i2] + Err1 = (W1 - Q1) / torch.diag(Hinv1).unsqueeze(0) + Losses1 = ((W1 - Q1) ** 2) / (torch.diag(Hinv1) ** 2).unsqueeze(0) + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + assert Q.shape == W.shape, f"Shape mismatch: {Q.shape} != {W.shape}" + + return Q + + def free(self): + self.H = None + self.Losses = None + torch.cuda.empty_cache() + cleanup_memory(verbos=False) diff --git a/src/chop/passes/module/transforms/gptq/quantize_dispatch.py b/src/chop/passes/module/transforms/gptq/quantize_dispatch.py new file mode 100644 index 000000000..ee4cfb694 --- /dev/null +++ b/src/chop/passes/module/transforms/gptq/quantize_dispatch.py @@ -0,0 +1,82 @@ +""" +Dispatch quantization to mxfp_quantizer_sim or mxint_quantizer_sim +based on format string, using Mase-style config dicts. 
+""" + +import torch +from torch import Tensor + +from chop.nn.quantizers.mxfp.mxfp import mxfp_quantizer_sim +from chop.nn.quantizers.mxfp.meta import MXFPMeta +from chop.nn.quantizers.mxint.mxint import mxint_quantizer_sim +from chop.nn.quantizers.mxint.meta import MXIntMeta + + +def _build_mxfp_meta(config: dict) -> MXFPMeta: + return MXFPMeta( + block_size=config["weight_block_size"], + scale_exp_bits=8, + element_exp_bits=config["weight_exponent_width"], + element_frac_bits=config["weight_frac_width"], + element_is_finite=True, + round_mode="rn", + ) + + +def _build_mxint_meta(config: dict) -> MXIntMeta: + return MXIntMeta( + block_size=config["weight_block_size"], + scale_bits=8, + element_bits=config["weight_width"], + ) + + +def quantize_tensor( + input: Tensor, + block_dim: int, + fmt: str, + config: dict, + quantile_search: bool, + act_tensor: Tensor | None = None, + dtype: torch.dtype | None = None, + cali_batch_size: int = 32, +) -> Tensor: + """ + Quantize a tensor using Mase quantizers, dispatching by format string. + + Args: + input: Weight tensor to quantize. + block_dim: Dimension for block quantization. + fmt: "mxfp" or "mxint". + config: Mase-style weight config dict with keys like + weight_block_size, weight_exponent_width, weight_frac_width (mxfp) + or weight_block_size, weight_width (mxint). + quantile_search: Enable percentile-based clipping search. + act_tensor: Optional activation tensor for calibration. + dtype: Output dtype. + cali_batch_size: Batch size for calibration search. 
@torch.no_grad()
def run_gptq(network, gptq_config):
    """
    Run GPTQ weight optimization on all nn.Linear layers in decoder blocks.

    Quantized weights are written back in-place to the existing nn.Linear
    modules so that the subsequent module-replacement pass can pick them up.

    Args:
        network: HuggingFace causal-LM model (e.g. LlamaForCausalLM).
        gptq_config: Dict with keys:
            model_name: str - HF model name (for tokenizer).
            device: str - e.g. "cuda:0".
            dataset: str - "wikitext2" | "c4" | "ptb".
            nsamples: int - calibration samples (default 128).
            seqlen: int - sequence length (default 2048).
            format: str - "mxfp" | "mxint".
            weight_config: dict - Mase-style weight config, e.g.
                {"weight_block_size": 32, "weight_exponent_width": 2, "weight_frac_width": 1}
            quantile_search: bool (default True).
            clip_search_y: bool (default False).
            cali_batch_size: int (default 32).
            checkpoint_dir: str | None.
            hf_token: str | None.
            max_layers: int | None - quantize at most this many layers.

    Returns:
        network with GPTQ-optimized weights (still nn.Linear modules).

    Raises:
        RuntimeError: if a quantized weight tensor does not match the shape
            of the layer weight it is meant to replace.
    """
    logging.info('-----GPTQ Quantization-----')

    # --- Configuration (defaults mirror the reference GPTQ implementation) ---
    model_name = gptq_config["model_name"]
    dev = gptq_config.get("device", "cuda:0")
    dataset = gptq_config.get("dataset", "wikitext2")
    nsamples = gptq_config.get("nsamples", 128)
    seqlen = gptq_config.get("seqlen", 2048)
    fmt = gptq_config["format"]
    weight_config = gptq_config["weight_config"]
    quantile_search = gptq_config.get("quantile_search", True)
    clip_search_y = gptq_config.get("clip_search_y", False)
    cali_batch_size = gptq_config.get("cali_batch_size", 32)
    checkpoint_dir = gptq_config.get("checkpoint_dir", None)
    hf_token = gptq_config.get("hf_token", None)
    max_layers = gptq_config.get("max_layers", None)

    # --- Checkpoint resuming: skip layers already quantized on disk ---
    start_layer = 0
    if checkpoint_dir is not None:
        max_quantized_layer = auto_load_quantized_layers(network, checkpoint_dir)
        if max_quantized_layer >= 0:
            start_layer = max_quantized_layer + 1
            logging.info(f"Resuming GPTQ from layer {start_layer}")

        if start_layer == len(network.model.layers):
            logging.info("All layers already quantized, skipping GPTQ")
            return network

    # --- Calibration data ---
    dataloader = get_loaders(
        dataset, nsamples=nsamples, seed=0, seqlen=seqlen,
        model=model_name, hf_token=hf_token,
    )

    # Disable the kv cache during calibration; restored before returning.
    use_cache = network.config.use_cache
    network.config.use_cache = False

    layers = network.model.layers

    # Move embedding + norm + rope to device so the capture pass can run;
    # decoder blocks are moved one at a time to bound GPU memory.
    network.model.embed_tokens = network.model.embed_tokens.to(dev)
    network.model.norm = network.model.norm.to(dev)
    rope = network.model.rotary_emb.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(network.parameters())).dtype

    # Buffer holding the inputs to the first decoder block, one row per sample.
    inps = torch.zeros(
        (nsamples, seqlen, network.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        """Wraps the first decoder block to capture its inputs, then aborts.

        Raising ValueError stops the forward pass right after the block's
        input is recorded — the standard GPTQ input-capture trick.
        """

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            # NOTE(review): assumes the decoder layer is called with a
            # `position_ids` kwarg — confirm against the installed
            # transformers version.
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            network(batch[0].to(dev))
        except ValueError:
            # Expected: raised by Catcher once the input is captured.
            pass
    layers[0] = layers[0].module
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    # Linear sublayers are quantized in dependency-respecting groups:
    # members of one group share the same input activation.
    sequential = [
        ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
        ['self_attn.o_proj'],
        ['mlp.up_proj', 'mlp.gate_proj'],
        ['mlp.down_proj'],
    ]

    end_layer = len(layers) if max_layers is None else min(start_layer + max_layers, len(layers))
    logging.info(f"GPTQ: quantizing layers {start_layer} to {end_layer - 1} (of {len(layers)} total)")

    for i in range(start_layer, end_layer):
        print(f'\nLayer {i}:', flush=True, end=' ')
        layer = layers[i].to(dev)
        full = find_qlayers(layer, layers=[torch.nn.Linear])

        for names in sequential:
            subset = {n: full[n] for n in names}

            gptq = {}
            for name in subset:
                print(f'{name}', end=' ', flush=True)
                gptq[name] = GPTQ(subset[name])

            pre_act = []

            def make_pre_hook():
                def pre_hook(_, inp):
                    pre_act.append(inp[0])
                return pre_hook

            def add_batch(name):
                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)
                return tmp

            # Forward hooks accumulate Hessian statistics; pre-hooks record
            # raw inputs for optional clip search.
            # NOTE(review): for grouped sublayers the pre-hook fires once per
            # module, appending the (identical) input multiple times per
            # sample — confirm the clip search expects this.
            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
                handles.append(subset[name].register_forward_pre_hook(make_pre_hook()))

            for j in range(nsamples):
                x = inps[j].unsqueeze(0)
                cos, sin = rope(x, position_ids)
                outs[j] = layer(
                    x,
                    attention_mask=attention_mask,
                    position_embeddings=(cos, sin),
                )[0]

            pre_act = torch.cat(pre_act, dim=0)

            for h in handles:
                h.remove()

            for name in subset:
                quantized_w = gptq[name].fasterquant(
                    activation=pre_act if clip_search_y else None,
                    fmt=fmt,
                    weight_config=weight_config,
                    percdamp=0.01,
                    cali_batch_size=cali_batch_size,
                    layer_name=f"layers{i}.{name}",
                    quant_search=quantile_search,
                )

                # Write quantized weights back in-place. Use an explicit check
                # (not `assert`, which is stripped under `python -O`) so a
                # shape mismatch can never silently corrupt the layer.
                if quantized_w.shape != gptq[name].layer.weight.shape:
                    raise RuntimeError(
                        f"GPTQ produced weight of shape {tuple(quantized_w.shape)} "
                        f"for layers{i}.{name}, expected "
                        f"{tuple(gptq[name].layer.weight.shape)}"
                    )
                gptq[name].layer.weight.data.copy_(quantized_w)
                gptq[name].free()

        # Forward pass with quantized weights to get inputs for next layer
        # (uses the same `rope` module moved to `dev` above).
        for j in range(nsamples):
            x = inps[j].unsqueeze(0)
            cos, sin = rope(x, position_ids)
            outs[j] = layer(
                x,
                attention_mask=attention_mask,
                position_embeddings=(cos, sin),
            )[0]

        layers[i] = layer.cpu()
        del layer
        del gptq
        torch.cuda.empty_cache()

        # Ping-pong: this layer's outputs become the next layer's inputs.
        inps, outs = outs, inps

        if checkpoint_dir is not None:
            save_layer_checkpoint(network, i, checkpoint_dir)

    network.config.use_cache = use_cache
    cleanup_memory(verbos=True)
    logging.info('-----GPTQ Quantization Done-----\n')

    return network
def cleanup_memory(verbos=True) -> None:
    """Run GC and clear GPU memory."""
    import inspect

    # Best-effort: name the caller in the log line; ignore failures from
    # exotic frames.
    caller_name = ''
    try:
        caller_name = f' (from {inspect.stack()[1].function})'
    except (ValueError, KeyError):
        pass

    def _reserved_bytes() -> int:
        # Sum reserved (cached) memory across every visible CUDA device.
        total = 0
        for idx in range(torch.cuda.device_count()):
            total += torch.cuda.memory_reserved(device=idx)
        return total

    before = _reserved_bytes()
    gc.collect()

    # Without CUDA there is no cache to release and nothing to report.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    after = _reserved_bytes()
    if verbos:
        gib = 1024 ** 3
        logging.info(
            f"GPU memory{caller_name}: {before / gib:.2f} -> {after / gib:.2f} GB"
            f" ({(after - before) / gib:.2f} GB)"
        )