Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ qtorch/*ipynb*
**/*/*.so
example/runs
example/data

# HIP build artifacts
qtorch/quant/quant_hip/*_hip.hip
qtorch/quant/quant_hip/*_hip.h
**/cifar10
**/cifar100
example/checkpoint
Expand All @@ -21,3 +25,4 @@ dist
**/data
docs/source/examples
playground/
test_results_20260106_154839.log
22 changes: 19 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# QPyTorch


# QPyTorch (ROCm Fork)
[![Downloads](https://pepy.tech/badge/qtorch)](https://pepy.tech/project/qtorch) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

#### News:
- **ROCm Support Added**: This fork adds full AMD GPU support via ROCm/HIP, enabling low-precision arithmetic simulation on AMD hardware.
- Updated to version 0.3.0:
- supporting subnorms now (#43). Thanks @danielholanda for his contribution!
- Updated to version 0.2.0:
Expand All @@ -28,7 +31,10 @@ Notably, QPyTorch supports quantizing different numbers in the training process
with customized low-precision formats. This eases the process of investigating
different precision settings and developing new deep learning architectures. More
concretely, QPyTorch implements fused kernels for quantization and integrates
smoothly with existing PyTorch kernels (e.g. matrix multiplication, convolution).
smoothly with existing PyTorch kernels (e.g. matrix multiplication, convolution).

**This fork extends QPyTorch with AMD GPU support through ROCm/HIP**, enabling the same low-precision
arithmetic simulations on AMD hardware that were previously only available on NVIDIA GPUs.

Recent research can be reimplemented easily through QPyTorch. We offer an
example replication of [WAGE](https://arxiv.org/abs/1802.04680) in a downstream
Expand Down Expand Up @@ -63,7 +69,8 @@ requirements:
- Python >= 3.6
- PyTorch >= 1.5.0
- GCC >= 4.9 on linux
- CUDA >= 10.1 on linux
- **For NVIDIA GPUs**: CUDA >= 10.1 on linux
- **For AMD GPUs**: ROCm 7.1.0 (tested version; other versions may work but are untested)

Install other requirements by:
```bash
Expand All @@ -75,6 +82,15 @@ Install QPyTorch through pip:
pip install qtorch
```

Or install from source (for ROCm support):
```bash
git clone https://github.com/Hrzwahusa/qtorch-rocm.git
cd qtorch-rocm
pip install -e .
```

The build system will automatically detect whether to build with CUDA or ROCm/HIP based on your environment.

For more details about compiler requirements,
please refer to [PyTorch extension tutorial](https://pytorch.org/tutorials/advanced/cpp_extension.html).

Expand Down
19 changes: 18 additions & 1 deletion qtorch/quant/quant_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,24 @@
],
)

if torch.cuda.is_available():
# Check if ROCm/HIP is available
if torch.cuda.is_available() and hasattr(torch.version, 'hip') and torch.version.hip is not None:
print("Loading QPyTorch with ROCm/HIP support...")
quant_cuda = load(
name="quant_cuda",
sources=[
os.path.join(current_path, "quant_hip/quant_hip.cpp"),
os.path.join(current_path, "quant_hip/bit_helper.hip"),
os.path.join(current_path, "quant_hip/sim_helper.hip"),
os.path.join(current_path, "quant_hip/block_kernel.hip"),
os.path.join(current_path, "quant_hip/float_kernel.hip"),
os.path.join(current_path, "quant_hip/fixed_point_kernel.hip"),
os.path.join(current_path, "quant_hip/quant.hip"),
],
extra_include_paths=[os.path.join(current_path, "quant_hip")],
)
elif torch.cuda.is_available():
print("Loading QPyTorch with CUDA support...")
quant_cuda = load(
name="quant_cuda",
sources=[
Expand Down
59 changes: 59 additions & 0 deletions qtorch/quant/quant_hip/bit_helper.hip
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Reinterpret the bits of a float as an unsigned 32-bit integer and back.
// Both macros take a POINTER so the cast happens on the address, not the value.
#define FLOAT_TO_BITS(x) (*reinterpret_cast<unsigned int*>(x))
#define BITS_TO_FLOAT(x) (*reinterpret_cast<float*>(x))

// Return the unbiased IEEE-754 single-precision exponent of *a, plus one
// to account for the implicit (virtual) leading mantissa bit.
__device__ __inline__ unsigned int extract_exponent(float *a) {
  unsigned int temp = *(reinterpret_cast<unsigned int*>(a));
  temp = (temp << 1 >> 24); // single precision: 1 sign bit, 23 mantissa bits
  return temp-127+1; // remove the exponent bias (127) and count the virtual bit
}

// Stochastic rounding on a raw IEEE-754 bit pattern: add the random bits that
// fall below the kept [man_bits] mantissa bits, then truncate those low bits.
__device__ __inline__ unsigned int round_bitwise_stochastic(unsigned int target,
                                                            unsigned int rand_prob,
                                                            int man_bits) {
  const unsigned int drop_mask = (1u << (23 - man_bits)) - 1u;
  const unsigned int perturbed = target + (rand_prob & drop_mask);
  return perturbed & ~drop_mask;
}

// Round-to-nearest on a raw IEEE-754 bit pattern, keeping [man_bits] mantissa
// bits: add half of the dropped ULP, then truncate the dropped low bits
// (ties therefore round away from zero in magnitude, not to even).
__device__ __inline__ unsigned int round_bitwise_nearest(unsigned int target,
                                                         int man_bits) {
  const unsigned int drop_mask = (1u << (23 - man_bits)) - 1u;
  const unsigned int half_ulp = 1u << (23 - man_bits - 1);
  return (target + half_ulp) & ~drop_mask;
}

// Saturate quantized_num so its exponent fits a format with [exp_bits]
// exponent bits: values that overflow are clipped to the largest magnitude
// representable with [man_bits] mantissa bits, carrying the sign of the
// pre-quantization value old_num. Exponent underflow is NOT handled here;
// callers (e.g. float_kernel_*) treat subnormals separately.
__device__ __inline__ unsigned int clip_exponent(int exp_bits, int man_bits,
                                                 unsigned int old_num,
                                                 unsigned int quantized_num) {
  // Zero has no exponent to clip; pass it through unchanged.
  if (quantized_num == 0)
    return quantized_num;

  int quantized_exponent_store = quantized_num << 1 >> 1 >> 23; // 1 sign bit, 23 mantissa bits
  int max_exponent_store = (1 << (exp_bits - 1)) + 127; // we are not reserving an exponent bit for infinity, nan, etc
  // Clipping value up: replace overflowed values with the format's maximum.
  if (quantized_exponent_store > max_exponent_store)
  {
    // All-ones mantissa truncated to man_bits, left in bit position.
    unsigned int max_man = (unsigned int)-1 << 9 >> 9 >> (23 - man_bits) << (23 - man_bits); // 1 sign bit, 8 exponent bits, 1 virtual bit
    unsigned int max_num = ((unsigned int)max_exponent_store << 23) | max_man;
    unsigned int old_sign = old_num >> 31 << 31;
    quantized_num = old_sign | max_num;
  }
  return quantized_num;
}


// Clip quantized_num so its exponent field does not exceed max_exponent.
// max_exponent is expected ALREADY shifted into bit positions 23..30 (as the
// block kernels produce it), so the comparison is done on in-position
// exponent bit patterns rather than on unbiased exponent values.
__device__ __inline__ unsigned int clip_max_exponent(int man_bits,
                                                     unsigned int max_exponent,
                                                     unsigned int quantized_num) {
  unsigned int quantized_exponent = quantized_num << 1 >> 24 << 23; // 1 sign bit, 23 mantissa bits
  if (quantized_exponent > max_exponent) {
    // Largest mantissa representable with [man_bits] bits, left in position.
    unsigned int max_man = (unsigned int ) -1 << 9 >> 9 >> (23-man_bits) << (23-man_bits); // 1 sign bit, 8 exponent bits
    unsigned int max_num = max_exponent | max_man;
    unsigned int old_sign = quantized_num >> 31 << 31;
    quantized_num = old_sign | max_num;
  }
  return quantized_num;
}
83 changes: 83 additions & 0 deletions qtorch/quant/quant_hip/block_kernel.hip
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include "quant_kernel.h"
#include "sim_helper.hip"
#include "bit_helper.hip"

// Block floating point quantization with stochastic rounding: each element
// keeps [man_bits] mantissa bits relative to the exponent of max_entry[index]
// (presumably the block's max magnitude, broadcast per element — confirm
// against the host-side caller). r supplies one random integer per element.
__global__ void block_kernel_stochastic(float* __restrict__ a,
                                        int* __restrict__ r,
                                        float* o, int size,
                                        float* __restrict__ max_entry,
                                        int man_bits) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < size) {
    unsigned int max_entry_bits = FLOAT_TO_BITS(&max_entry[index]);
    // Exponent field of the shared maximum, kept in bit positions 23..30.
    unsigned int max_exp = max_entry_bits << 1 >> 24 << 23;
    // NOTE(review): the 6x rebase offset matches upstream QPyTorch; adding it
    // pins the sum's exponent to the shared one so that bitwise truncation
    // drops the same low-order bits block-wide — confirm factor vs upstream.
    float base_float = 6*BITS_TO_FLOAT(&max_exp);

    float target_rebase = a[index]+base_float;
    unsigned int target_bits = FLOAT_TO_BITS(&target_rebase);
    unsigned int rand_prob = (unsigned int) r[index];
    unsigned int quantized = round_bitwise_stochastic(target_bits, rand_prob, man_bits);
    float quantize_float = BITS_TO_FLOAT(&quantized)-base_float; // undo the rebase

    unsigned int quantize_bits = FLOAT_TO_BITS(&quantize_float) ;
    // man_bits-2: accounts for the sign bit and the virtual (implicit) bit.
    unsigned int clip_quantize = clip_max_exponent(man_bits-2, max_exp, quantize_bits);
    quantize_float = BITS_TO_FLOAT(&clip_quantize);
    o[index] = quantize_float;
  }
}

// Block floating point quantization with nearest rounding: each element keeps
// [man_bits] mantissa bits relative to the exponent of max_entry[index]
// (presumably the block's max magnitude, broadcast per element — confirm
// against the host-side caller).
__global__ void block_kernel_nearest(float* __restrict__ a,
                                     float* o, int size,
                                     float* __restrict__ max_entry,
                                     int man_bits) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < size) {
    unsigned int max_entry_bits = FLOAT_TO_BITS(&max_entry[index]);
    // Exponent field of the shared maximum, kept in bit positions 23..30.
    unsigned int max_exp = max_entry_bits << 1 >> 24 << 23;
    // NOTE(review): the 6x rebase offset matches upstream QPyTorch; adding it
    // pins the sum's exponent to the shared one so that bitwise rounding acts
    // at the same absolute position block-wide — confirm factor vs upstream.
    float base_float = 6*BITS_TO_FLOAT(&max_exp);

    float target_rebase = a[index]+base_float;
    unsigned int target_bits = FLOAT_TO_BITS(&target_rebase);
    unsigned int quantized = round_bitwise_nearest(target_bits, man_bits);
    float quantize_float = BITS_TO_FLOAT(&quantized)-base_float; // undo the rebase

    unsigned int quantize_bits = FLOAT_TO_BITS(&quantize_float);
    unsigned int clip_quantize = clip_max_exponent(man_bits-2, max_exp, quantize_bits); // sign bit, virtual bit
    quantize_float = BITS_TO_FLOAT(&clip_quantize);

    o[index] = quantize_float;
  }
}

// Simulated block floating point with stochastic rounding: every element is
// rounded on a fixed-point grid whose step exponent (sigma) is derived from
// the exponent of *max_entry and the word length wl.
__global__ void block_kernel_sim_stochastic(float* __restrict__ a,
                                            float* __restrict__ r,
                                            float* o, int size,
                                            float* max_entry,
                                            int wl) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= size) return;
  const int shared_exp = (int) extract_exponent(max_entry);
  const int sigma = shared_exp - (wl - 1);
  o[i] = round(a[i], r[i], sigma);
}

// Simulated block floating point with nearest rounding: every element is
// rounded on a fixed-point grid whose step exponent (sigma) is derived from
// the exponent of *max_entry and the word length wl.
__global__ void block_kernel_sim_nearest(float* __restrict__ a,
                                         float* o, int size,
                                         float* max_entry,
                                         int wl) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= size) return;
  const int shared_exp = (int) extract_exponent(max_entry);
  const int sigma = shared_exp - (wl - 1);
  o[i] = nearest_round(a[i], sigma);
}
77 changes: 77 additions & 0 deletions qtorch/quant/quant_hip/fixed_point_kernel.hip
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "quant_kernel.h"
#include "sim_helper.hip"


// Saturate a into [min, max]; the upper bound is checked first, matching the
// original call-site semantics for degenerate ranges.
template <typename T>
__device__ __inline__ T clamp_helper(T a, T min, T max) {
  if (a > max) return max;
  if (a < min) return min;
  return a;
}

// Saturate a into [min, max] and record in *mask whether clamping occurred
// (1 when either bound was hit, 0 when the value passed through unchanged).
// The upper bound is checked first, matching clamp_helper.
template <typename T>
__device__ __inline__ T clamp_mask_helper(T a, T min, T max, uint8_t* mask) {
  if (a > max) { *mask = 1; return max; }
  if (a < min) { *mask = 1; return min; }
  *mask = 0;
  return a;
}

// Quantize an array of reals to fixed point with word length [wl] and [fl]
// fractional bits, using stochastic rounding driven by r. 2**-[sigma] is the
// smallest unit of the fixed point representation; when use_clamp is set,
// results saturate into [t_min, t_max].
__global__ void fixed_point_quantize_kernel_stochastic(float* __restrict__ a,
                                                       float* __restrict__ r,
                                                       float* o, int size,
                                                       int sigma, bool use_clamp,
                                                       float t_min, float t_max) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= size) return;
  float q = round(a[i], r[i], sigma);
  o[i] = use_clamp ? clamp_helper(q, t_min, t_max) : q;
}

// Quantize an array of reals to fixed point with word length [wl] and [fl]
// fractional bits, using nearest-neighbor rounding. 2**-[sigma] is the
// smallest unit of the fixed point representation; when use_clamp is set,
// results saturate into [t_min, t_max].
__global__ void fixed_point_quantize_kernel_nearest(float* __restrict__ a,
                                                    float* o, int size,
                                                    int sigma, bool use_clamp,
                                                    float t_min, float t_max) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= size) return;
  float q = nearest_round(a[i], sigma);
  o[i] = use_clamp ? clamp_helper(q, t_min, t_max) : q;
}

// Fixed-point quantization with stochastic rounding that always clamps to
// [t_min, t_max] and records per-element whether clamping happened in m.
__global__ void fixed_point_quantize_kernel_mask_stochastic(float* __restrict__ a,
                                                            float* __restrict__ r,
                                                            float* o, uint8_t* m,
                                                            int size, int sigma,
                                                            float t_min, float t_max) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= size) return;
  const float q = round(a[i], r[i], sigma);
  o[i] = clamp_mask_helper(q, t_min, t_max, m + i);
}

// Fixed-point quantization with nearest rounding that always clamps to
// [t_min, t_max] and records per-element whether clamping happened in m.
__global__ void fixed_point_quantize_kernel_mask_nearest(float* __restrict__ a,
                                                         float* o, uint8_t* m,
                                                         int size, int sigma,
                                                         float t_min, float t_max) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= size) return;
  const float q = nearest_round(a[i], sigma);
  o[i] = clamp_mask_helper(q, t_min, t_max, m + i);
}
70 changes: 70 additions & 0 deletions qtorch/quant/quant_hip/float_kernel.hip
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "quant_kernel.h"
#include "bit_helper.hip"

// Quantize floats into a low-precision floating point format with [exp_bits]
// exponent bits and [man_bits] mantissa bits, using stochastic rounding
// driven by the per-element random integers in r. Inputs whose exponent
// falls below the target format's minimum normal exponent are handled via a
// shift-add trick that emulates subnormal rounding.
__global__ void float_kernel_stochastic(float* __restrict__ a,
                                        int* __restrict__ r,
                                        float* o, int size,
                                        int man_bits,
                                        int exp_bits) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < size) {
    unsigned int rand_prob = (unsigned int) r[index];
    unsigned int target,quantize_bits;
    target = FLOAT_TO_BITS(&a[index]);
    float quantized;

    int target_exp = (target << 1 >> 1 >> 23) -127; // unbiased input exponent
    int min_exp = -((1 << (exp_bits - 1)) - 2); // smallest normal exponent of target format
    bool subnormal = (target_exp < min_exp);
    if (subnormal){
      // Add 2^min_exp carrying the input's sign: this renormalizes the value
      // so stochastic rounding happens at subnormal granularity; the offset
      // is subtracted again after rounding.
      float shift_float,val;
      int shift_bits = ((127+min_exp)<<23) | (target >> 31 <<31);
      shift_float = BITS_TO_FLOAT(&shift_bits);
      val=a[index]+shift_float;
      target = FLOAT_TO_BITS(&val);
      quantize_bits = round_bitwise_stochastic(target, rand_prob, man_bits);
      quantized = BITS_TO_FLOAT(&quantize_bits) - shift_float;
    }
    else{
      // Normal range: round the bit pattern, then clip exponent overflow.
      quantize_bits = round_bitwise_stochastic(target, rand_prob, man_bits);
      quantize_bits = clip_exponent(exp_bits, man_bits, target, quantize_bits);
      quantized = BITS_TO_FLOAT(&quantize_bits);
    }
    o[index] = quantized;
  }
}

// Quantize floats into a low-precision floating point format with [exp_bits]
// exponent bits and [man_bits] mantissa bits, using round-to-nearest.
// Inputs whose exponent falls below the target format's minimum normal
// exponent are handled via a shift-add trick that emulates subnormal
// rounding.
__global__ void float_kernel_nearest(float* __restrict__ a,
                                     float* o, int size,
                                     int man_bits,
                                     int exp_bits) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < size) {
    unsigned int target,quantize_bits;
    target = FLOAT_TO_BITS(&a[index]);
    float quantized;

    int target_exp = (target << 1 >> 1 >> 23) -127; // unbiased input exponent
    int min_exp = -((1 << (exp_bits - 1)) - 2); // smallest normal exponent of target format
    bool subnormal = (target_exp < min_exp);
    if (subnormal){
      // Add 2^min_exp carrying the input's sign: this renormalizes the value
      // so nearest rounding happens at subnormal granularity; the offset is
      // subtracted again after rounding.
      float shift_float,val;
      int shift_bits = ((127+min_exp)<<23) | (target >> 31 <<31);
      shift_float = BITS_TO_FLOAT(&shift_bits);
      val=a[index]+shift_float;
      target = FLOAT_TO_BITS(&val);
      quantize_bits = round_bitwise_nearest(target, man_bits);
      quantized = BITS_TO_FLOAT(&quantize_bits) - shift_float;
    }
    else{
      // Normal range: round the bit pattern, then clip exponent overflow.
      quantize_bits = round_bitwise_nearest(target, man_bits);
      quantize_bits = clip_exponent(exp_bits, man_bits, target, quantize_bits);
      quantized = BITS_TO_FLOAT(&quantize_bits);
    }
    o[index] = quantized;
  }
}
Loading