Commits (60)
1b1e8f4
init calibration less quant refractor
Qubitium Mar 9, 2026
661326b
refractor quant config
Qubitium Mar 9, 2026
781ba2c
refractor quant config 2
Qubitium Mar 9, 2026
995b5da
refractor quant config 3
Qubitium Mar 9, 2026
55084cd
rename calibrationless to weight_only
Qubitium Mar 9, 2026
c00531f
fix awq oom
Qubitium Mar 9, 2026
c1f125b
v6.0.0 update
Qubitium Mar 9, 2026
dd85bc8
cleanup
Qubitium Mar 9, 2026
e2cc88e
stable return tuples
Qubitium Mar 9, 2026
3e29f23
accelerate depend 1.13.0
Qubitium Mar 9, 2026
ee1ba3f
cleanup hf kernel gptq/awq post_init loading
Qubitium Mar 9, 2026
e256c76
fix test
Qubitium Mar 9, 2026
88e4e9b
fix SmoothMAD overly-aggressive clipping: normalize k to behave like …
Qubitium Mar 9, 2026
a3b5ee7
simplify
Qubitium Mar 9, 2026
b5da414
initial gguf
Qubitium Mar 9, 2026
1cb7bca
gguf refractor
Qubitium Mar 9, 2026
75eb3db
gguf refractor
Qubitium Mar 9, 2026
0287831
gguf unit test
Qubitium Mar 10, 2026
9d3c96d
fix gguf should directly bypass rtn with optional smoother
Qubitium Mar 10, 2026
e8da713
add test
Qubitium Mar 10, 2026
c14d624
refractor config
Qubitium Mar 10, 2026
c76acac
refractor config part2
Qubitium Mar 10, 2026
af124eb
gguf dequant to native type, not fp32
Qubitium Mar 10, 2026
c52e82b
fuse gguf ops
Qubitium Mar 10, 2026
8ef27ed
autotune
Qubitium Mar 10, 2026
5e9ac59
fold autotune api to base kernel
Qubitium Mar 10, 2026
c4d1bc6
refractor
Qubitium Mar 10, 2026
764e1bc
Merge origin/main into refractor-simple-quant
Qubitium Mar 10, 2026
ce82b8e
add missing
Qubitium Mar 10, 2026
2c4488f
add missing
Qubitium Mar 10, 2026
a4bdbf7
defuser update fix
Qubitium Mar 10, 2026
50fee0d
optional depend on llama-cpp
Qubitium Mar 10, 2026
f1974a9
add quality tests to guard against regressions
Qubitium Mar 10, 2026
59f8329
refractor naming
Qubitium Mar 10, 2026
f690085
split gguf_cpp_cpu and gguf_cpp_cuda kernel
Qubitium Mar 10, 2026
339b7f7
add gguf_triton fused kernel
Qubitium Mar 10, 2026
f3b1125
check for gguf k block padding
Qubitium Mar 10, 2026
59da42c
optimize gguf triton kernel vs cpp
Qubitium Mar 10, 2026
3be5d85
prioriize gguf auto kernel selection
Qubitium Mar 10, 2026
45266af
Merge remote-tracking branch 'origin/main' into refractor-simple-quant
Qubitium Mar 10, 2026
3313e47
bug fixes
Qubitium Mar 10, 2026
bae1164
use qcfg.format and deprecate qcfg.checkpoint_format as much as possi…
Qubitium Mar 11, 2026
1090f01
protcol design
Qubitium Mar 11, 2026
3de9a80
Merge remote-tracking branch 'origin/main' into refractor-simple-quant
Qubitium Mar 11, 2026
c070082
simplify
Qubitium Mar 11, 2026
6a76fc0
unit test simple protocol
Qubitium Mar 11, 2026
b757c92
refractor match
Qubitium Mar 11, 2026
980778d
update tests
Qubitium Mar 11, 2026
1999c73
update tests
Qubitium Mar 11, 2026
c3e5d0f
clarify failscale placement
Qubitium Mar 12, 2026
007b793
failsafe to fallback rename
Qubitium Mar 12, 2026
c03aab2
add exl3 (exllama v3) support
Qubitium Mar 12, 2026
8bc5bf8
add exl3 unit test
Qubitium Mar 12, 2026
01af1bf
Merge remote-tracking branch 'origin/main' into refractor-simple-quant
Qubitium Mar 12, 2026
097b054
add fp8 support
Qubitium Mar 12, 2026
4271453
soft deprecate `quant_method` for `method`, `checkpoint_format` for `…
Qubitium Mar 12, 2026
1c762ca
hard deprecate qcfg.is_marlin_format
Qubitium Mar 12, 2026
94d5d99
Merge remote-tracking branch 'origin/main' into refractor-simple-quant
Qubitium Mar 12, 2026
28ed59f
resolve github code quality issues
Qubitium Mar 12, 2026
4e1e852
update readme
Qubitium Mar 12, 2026
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -1,11 +1,13 @@
recursive-include gptqmodel_ext/awq *.h *.cuh *.cu *.cpp
recursive-include gptqmodel_ext/exllama *.h *.cuh *.cu *.cpp
recursive-include gptqmodel_ext/exllamav3 *.h *.hpp *.cuh *.cu *.cpp *.c
recursive-include gptqmodel_ext/exllamav2 *.h *.cuh *.cu *.cpp
recursive-include gptqmodel_ext/exllama_eora/eora *.h *.cuh *.cu *.cpp *.py
recursive-include gptqmodel_ext/marlin *.h *.cuh *.cu *.cpp *.hpp
recursive-include gptqmodel_ext/machete *.h *.hpp *.cuh *.cu *.cpp *.py
recursive-include gptqmodel_ext/cutlass_extensions *.h *.hpp *.cuh *.cu *.cpp *.py
recursive-include gptqmodel_ext/qqq *.h *.cuh *.cu *.cpp
recursive-include gptqmodel/exllamav3/util/hadamard_data *.txt
include licenses/*
include gptqmodel_ext/pack_block_cpu.cpp
include gptqmodel_ext/marlin/generate_kernels.py
110 changes: 108 additions & 2 deletions README.md
@@ -145,7 +145,7 @@ Fixed quantization of OPT and DeepSeek V2-Lite models. Fixed inference for DeepS
## What is GPT-QModel?
GPT-QModel is a production-ready LLM compression/quantization toolkit with hardware-accelerated inference support for both CPU and GPU via HF Transformers, vLLM, and SGLang.

GPT-QModel currently supports GPTQ, AWQ, QQQ, GPTAQ, EoRa, GAR, with more quantization methods and enhancements planned.
GPT-QModel currently supports GPTQ, AWQ, QQQ, GGUF, FP8, EXL3, GPTAQ, EoRa, and GAR, with more quantization methods and enhancements planned.

## Quantization Support

@@ -155,16 +155,21 @@ GPT-QModel is a modular design supporting multiple quantization methods and feat
|---------------------------|------------|---|---|---|---------------|
| GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ |
| AWQ | ✅ | ✅ | ✅ | ✅ | ✅ |
| GGUF | ✅ | x | x | x | x |
| FP8 | ✅ | x | x | x | x |
| Exllama V3 / EXL3 | ✅ | x | x | x | x |
| EoRA | ✅ | ✅ | ✅ | ✅ | x |
| Group Aware Act Reordering | ✅ | ✅ | ✅ | ✅ | ✅ |
| QQQ | ✅ | x | x | x | x |
| Rotation | ✅ | x | x | x | x |
| GPTAQ | ✅ | ✅ | ✅ | ✅ | ✅ |

`GGUF`, `FP8`, and `EXL3` are currently native GPT-QModel quantization/runtime paths. `vLLM` and `SGLang` integration currently targets `GPTQ` and `AWQ`.
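
For serving, `GPTQ`/`AWQ` checkpoints produced by GPT-QModel load directly in vLLM. A minimal sketch using vLLM's `LLM` API (the checkpoint path is a placeholder for any GPTQ-format model saved by GPT-QModel):

```py
from vllm import LLM, SamplingParams

# Placeholder path: any GPTQ-format checkpoint saved by GPT-QModel.
llm = LLM(model="./Llama-3.2-1B-Instruct-gptq-4bit", quantization="gptq")

params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Which city is the capital city of France?"], params)
print(outputs[0].outputs[0].text)
```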

## Features
* ✨ Native integration with HF [Transformers](https://github.com/huggingface/transformers), [Optimum](https://github.com/huggingface/optimum), and [Peft](https://github.com/huggingface/peft)
* 🚀 [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) inference integration for quantized models with format = `FORMAT.[GPTQ/AWQ]`
* ✨ GPTQ, AWQ, and QQQ quantization format with hardware-accelerated inference kernels.
* ✨ GPTQ, AWQ, QQQ, GGUF, FP8, and EXL3 quantization support.
* 🚀 Quantize MoE models with ease even with extreme routing activation bias via `Moe.Routing` and/or `FailSafe`.
* 🚀 Data Parallelism for 80%+ reduction in quantization time on multi-GPU setups.
* 🚀 Optimized for Python >= 3.13t (free threading) with lock-free threading.
@@ -177,6 +182,15 @@ GPT-QModel is a modular design supporting multiple quantization methods and feat
* 🚀 [Microsoft/BITBLAS](https://github.com/microsoft/BitBLAS) optimized tile based inference.
* 💯 CI unit-test coverage for all supported models and kernels including post-quantization quality regression.

## Who's Using GPT-QModel?

Selected public references where teams or companies explicitly mention `GPTQModel` in documentation, integration notes, or quantized model usage. This is not an exhaustive customer list.

* <img src="https://cdn.simpleicons.org/huggingface/FFD21E" alt="Hugging Face logo" height="14"> Hugging Face
* <img src="https://cdn.simpleicons.org/intel/0071C5" alt="Intel logo" height="14"> Intel
* <img src="https://cdn.simpleicons.org/nvidia/76B900" alt="NVIDIA logo" height="14"> NVIDIA
* <img src="https://cdn.simpleicons.org/alibabacloud/FF6A00" alt="Alibaba Cloud logo" height="14"> Alibaba Cloud


## Quality: GPTQ 4bit can match native BF16:
🤗 [ModelCloud quantized Vortex models on HF](https://huggingface.co/collections/ModelCloud/vortex-673743382af0a52b2a8b9fe2)
@@ -289,6 +303,76 @@ model.quantize(calibration_dataset, batch_size=1)
model.save(quant_path)
```

#### Other Quantization Formats

`GPTQ`, `AWQ`, and `EXL3` are calibration-based. `GGUF` and `FP8` are weight-only and should be quantized with `calibration=None`.

##### GGUF Example: Llama 3.2 1B Instruct

```py
from gptqmodel import BACKEND, GGUFConfig, GPTQModel

model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-GGUF-Q4_K_M"

qcfg = GGUFConfig(
    bits=4,
    format="q_k_m",
)

model = GPTQModel.load(model_id, qcfg)
model.quantize(calibration=None, backend=BACKEND.GGUF_TORCH)
model.save(quant_path)
```
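
A sketch of reloading the saved checkpoint for inference through the same GGUF backend, mirroring the bundled `examples/quantization/llama3_2_1b_gguf_q4_k_m.py` script (the CUDA device string is an assumption; use `"cpu"` if no GPU is available):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
quantized = GPTQModel.from_quantized(
    model_id_or_path=quant_path,
    backend=BACKEND.GGUF_TORCH,
    device="cuda:0",  # assumption: adjust to your hardware
)

inputs = tokenizer("Which city is the capital city of France?", return_tensors="pt").to("cuda:0")
output_ids = quantized.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```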

##### FP8 Example: Llama 3.2 1B Instruct

```py
from gptqmodel import BACKEND, GPTQModel, QuantizeConfig

model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-FP8-E4M3"

qcfg = QuantizeConfig(
    method="fp8",
    format="float8_e4m3fn",  # or "float8_e5m2"
    bits=8,
    weight_scale_method="row",
)

model = GPTQModel.load(model_id, qcfg)
model.quantize(calibration=None, backend=BACKEND.TORCH)
model.save(quant_path)
```
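
When picking between the two `format` values above: `float8_e4m3fn` spends more bits on the mantissa (finer precision, max ≈ 448), while `float8_e5m2` spends them on the exponent (wider range, max ≈ 57344). A quick way to compare, assuming a PyTorch build with FP8 dtypes (>= 2.1):

```py
import torch

# Compare the numeric envelopes of the two FP8 variants.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, smallest_normal={info.smallest_normal}, eps={info.eps}")
```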

##### Exllama V3 / EXL3 Example: Llama 3.2 1B Instruct

```py
from datasets import load_dataset
from gptqmodel import BACKEND, GPTQModel, QuantizeConfig

model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-EXL3"

calibration_dataset = load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train",
).select(range(1024))["text"]

qcfg = QuantizeConfig(
    method="exl3",
    format="exl3",
    bits=4.0,        # target average bits-per-weight
    head_bits=6.0,   # optional higher bitrate for attention heads / sensitive tensors
    codebook="mcg",  # one of: mcg, mul1, 3inst
)

model = GPTQModel.load(model_id, qcfg)
model.quantize(calibration_dataset, batch_size=1, backend=BACKEND.EXLLAMA_V3)
model.save(quant_path)
```
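
Because `bits` is a target *average* bits-per-weight, a quick back-of-the-envelope size check is parameters × bpw / 8. The numbers below are illustrative only (approximate parameter count; real checkpoints also store scales, codebook metadata, and any higher-bit `head_bits` tensors):

```py
n_params = 1.24e9  # approx. parameter count of Llama-3.2-1B-Instruct
avg_bits = 4.0     # the `bits` target above
approx_gib = n_params * avg_bits / 8 / 1024**3
print(f"~{approx_gib:.2f} GiB of packed weights at {avg_bits} bpw")  # ~0.58 GiB
```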

#### MoE Quantization

Some MoE (mixture of experts) models have extremely uneven/biased routing (distribution of tokens) to the `experts`, causing some expert modules to receive close-to-zero activated tokens and preventing calibration-based quantization (GPTQ/AWQ) from completing.
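
Before committing to a calibration run, the routing imbalance can be measured directly. The sketch below is not part of the GPT-QModel API: it hooks the router/gate modules of a Hugging Face MoE model and counts how many calibration tokens reach each expert; the `.gate` module-name suffix, the router output shape, and the `top_k` value are assumptions that vary per architecture.

```py
from collections import Counter

import torch


def count_expert_hits(model, dataloader, top_k=2):
    """Count how many tokens each expert receives across a calibration pass."""
    hits = Counter()
    handles = []

    def hook(_module, _inputs, output):
        # Assumption: the router returns logits shaped [tokens, n_experts].
        logits = output[0] if isinstance(output, tuple) else output
        top = torch.topk(logits, k=top_k, dim=-1).indices
        hits.update(top.flatten().tolist())

    for name, module in model.named_modules():
        if name.endswith(".gate"):  # common router name; adjust per architecture
            handles.append(module.register_forward_hook(hook))

    with torch.inference_mode():
        for batch in dataloader:
            model(**batch)

    for handle in handles:
        handle.remove()
    return hits  # experts with near-zero counts are the ones that starve calibration
```
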
@@ -489,6 +573,17 @@ Models quantized by GPT-QModel are inference compatible with HF Transformers (mi
year={2023}
}

# GGUF / llama.cpp
@misc{ggerganov2023gguf,
author = {Georgi Gerganov and ggml-org contributors},
title = {llama.cpp and the GGUF model format},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ggml-org/llama.cpp}},
note = {Canonical GGUF implementation and format reference; see also \url{https://github.com/ggml-org/llama.cpp/wiki/dev-notes}},
year = {2023}
}

# EoRA
@article{liu2024eora,
title={EoRA: Training-free Compensation for Compressed LLM with Eigenspace Low-Rank Approximation},
@@ -528,6 +623,17 @@ Models quantized by GPT-QModel are inference compatible with HF Transformers (mi
journal={arXiv preprint arXiv:2406.09904},
year={2024}
}

# ExLlama V3 / EXL3
@misc{turboderp2026exllamav3,
author = {turboderp and exllamav3 contributors},
title = {ExLlamaV3 and the EXL3 quantization format},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/turboderp-org/exllamav3}},
note = {Project repository and EXL3 format documentation: \url{https://github.com/turboderp-org/exllamav3/blob/master/doc/exl3.md}},
year = {2026}
}
```

## Quick Notes
96 changes: 96 additions & 0 deletions examples/quantization/llama3_2_1b_gguf_q4_k_m.py
@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-License-Identifier: Apache-2.0

import logging
import os
from pathlib import Path

import torch
from transformers import AutoTokenizer

from gptqmodel import BACKEND, GPTQModel
from gptqmodel.quantization import GGUFConfig


os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:256")

SOURCE_MODEL = "/monster/data/model/Llama-3.2-1B-Instruct"
OUTPUT_DIR = "./Llama-3.2-1B-Instruct-GGUF-Q4_K_M"


def main() -> None:
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(SOURCE_MODEL, use_fast=True)

    qconfig = GGUFConfig(
        bits=4,
        format="q_k_m",
        smoother=None,
        offload_to_disk=True,
        offload_to_disk_path="./gptqmodel_offload",
    )

    print("Resolved quantize config:")
    print(f" type = {type(qconfig).__name__}")
    print(f" format = {qconfig.format}")
    print(f" bits = {qconfig.bits!r}")
    print(f" bits_s = {str(qconfig.bits)}")

    model = GPTQModel.from_pretrained(
        model_id_or_path=SOURCE_MODEL,
        quantize_config=qconfig,
        trust_remote_code=False,
    )

    quant_log = model.quantize(
        calibration=None,
        tokenizer=tokenizer,
        backend=BACKEND.GGUF_TORCH,
    )
    print("Quantize lifecycle keys:", list(quant_log.keys()))

    out_dir = Path(OUTPUT_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)

    model.save(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))

    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    quantized = GPTQModel.from_quantized(
        model_id_or_path=str(out_dir),
        backend=BACKEND.GGUF_TORCH,
        device=device,
        trust_remote_code=False,
    )

    print("Inference kernel:", quantized.qlinear_kernel.__name__)

    prompt = "Which city is the capital city of France?"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.inference_mode():
        output_ids = quantized.generate(
            **inputs,
            max_new_tokens=48,
            do_sample=False,
        )

    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print("\nPrompt:")
    print(prompt)
    print("\nGeneration:")
    print(text)


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
12 changes: 11 additions & 1 deletion gptqmodel/__init__.py
@@ -50,7 +50,17 @@

from .models import GPTQModel, get_best_device
from .models.auto import ASCII_LOGO
from .quantization import BaseQuantizeConfig, GPTAQConfig, QuantizeConfig
from .quantization import (
    AWQQuantizeConfig,
    BaseQuantizeConfig,
    GGUFConfig,
    GGUFQuantizeConfig,
    GPTAQConfig,
    GPTQQuantizeConfig,
    QuantizeConfig,
    RTNQuantizeConfig,
    WeightOnlyConfig,
)
from .utils import BACKEND
from .utils.exllama import exllama_set_max_input_length
from .version import __version__
7 changes: 5 additions & 2 deletions gptqmodel/adapter/adapter.py
@@ -115,8 +115,11 @@ def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        # out = out + ((x @ self.lora_A) @ self.lora_B)

        # native quantized model/eora is float16 for gptq but for training, we may load the model as bfloat16 for accuracy
        if x.dtype != self.lora_A.dtype:
            log.info.once(f"Adapter: Lora A/B auto changed from `{self.lora_A.dtype}` to `{x.dtype}` to match forward input dtype.")
        if x.dtype != self.lora_A.dtype or x.device != self.lora_A.device:
            log.info.once(
                f"Adapter: Lora A/B auto changed from `{self.lora_A.dtype}` on `{self.lora_A.device}` "
                f"to `{x.dtype}` on `{x.device}` to match forward input."
            )
            self.lora_A = self.lora_A.to(device=x.device, dtype=x.dtype)
            self.lora_B = self.lora_B.to(device=x.device, dtype=x.dtype)

13 changes: 13 additions & 0 deletions gptqmodel/exllamav3/CREDITS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
This directory vendors the EXL3 kernel and quantizer pieces adapted from `turboderp-org/exllamav3`.

Primary upstream source:
- https://github.com/turboderp-org/exllamav3

Ported components in this repo:
- `gptqmodel/exllamav3/ext.py`
- `gptqmodel/exllamav3/modules/quant/exl3.py`
- `gptqmodel/exllamav3/modules/quant/exl3_lib/quantize.py`
- `gptqmodel/exllamav3/util/*`
- `gptqmodel_ext/exllamav3/*`

The code remains self-contained inside GPTQModel and does not depend on the external `exllamav3` Python package.
Empty file added gptqmodel/exllamav3/__init__.py
Empty file.