diff --git a/.gitignore b/.gitignore index 08c16b8..99994c3 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,10 @@ qtorch/*ipynb* **/*/*.so example/runs example/data + +# HIP build artifacts +qtorch/quant/quant_hip/*_hip.hip +qtorch/quant/quant_hip/*_hip.h **/cifar10 **/cifar100 example/checkpoint @@ -21,3 +25,4 @@ dist **/data docs/source/examples playground/ +test_results_20260106_154839.log diff --git a/README.md b/README.md index 46556e2..50a5284 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,10 @@ -# QPyTorch + + +# QPyTorch (ROCm Fork) [![Downloads](https://pepy.tech/badge/qtorch)](https://pepy.tech/project/qtorch) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) #### News: +- **ROCm Support Added**: This fork adds full AMD GPU support via ROCm/HIP, enabling low-precision arithmetic simulation on AMD hardware. - Updated to version 0.3.0: - supporting subnorms now (#43). Thanks @danielholanda for his contribution! - Updated to version 0.2.0: @@ -28,7 +31,10 @@ Notably, QPyTorch supports quantizing different numbers in the training process with customized low-precision formats. This eases the process of investigating different precision settings and developing new deep learning architectures. More concretely, QPyTorch implements fused kernels for quantization and integrates -smoothly with existing PyTorch kernels (e.g. matrix multiplication, convolution). +smoothly with existing PyTorch kernels (e.g. matrix multiplication, convolution). + +**This fork extends QPyTorch with AMD GPU support through ROCm/HIP**, enabling the same low-precision +arithmetic simulations on AMD hardware that were previously only available on NVIDIA GPUs. Recent researches can be reimplemented easily through QPyTorch. We offer an example replication of [WAGE](https://arxiv.org/abs/1802.04680) in a downstream @@ -63,7 +69,8 @@ requirements: - Python >= 3.6 - PyTorch >= 1.5.0 - GCC >= 4.9 on linux -- CUDA >= 10.1 on linux +- **For NVIDIA GPUs**: CUDA >= 10.1 on linux +- **For AMD GPUs**: ROCm 7.1.0 (tested version; other versions may work but are untested) Install other requirements by: ```bash @@ -75,6 +82,15 @@ Install QPyTorch through pip: pip install qtorch ``` +Or install from source (for ROCm support): +```bash +git clone https://github.com/Hrzwahusa/qtorch-rocm.git +cd qtorch-rocm +pip install -e . +``` + +The build system will automatically detect whether to build with CUDA or ROCm/HIP based on your environment. + For more details about compiler requirements, please refer to [PyTorch extension tutorial](https://pytorch.org/tutorials/advanced/cpp_extension.html). diff --git a/qtorch/quant/quant_function.py b/qtorch/quant/quant_function.py index 061676c..c2fb223 100644 --- a/qtorch/quant/quant_function.py +++ b/qtorch/quant/quant_function.py @@ -16,7 +16,24 @@ ], ) -if torch.cuda.is_available(): +# Check if ROCm/HIP is available +if torch.cuda.is_available() and hasattr(torch.version, 'hip') and torch.version.hip is not None: + print("Loading QPyTorch with ROCm/HIP support...") + quant_cuda = load( + name="quant_cuda", + sources=[ + os.path.join(current_path, "quant_hip/quant_hip.cpp"), + os.path.join(current_path, "quant_hip/bit_helper.hip"), + os.path.join(current_path, "quant_hip/sim_helper.hip"), + os.path.join(current_path, "quant_hip/block_kernel.hip"), + os.path.join(current_path, "quant_hip/float_kernel.hip"), + os.path.join(current_path, "quant_hip/fixed_point_kernel.hip"), + os.path.join(current_path, "quant_hip/quant.hip"), + ], + extra_include_paths=[os.path.join(current_path, "quant_hip")], + ) +elif torch.cuda.is_available(): + print("Loading QPyTorch with CUDA support...") quant_cuda = load( name="quant_cuda", sources=[ diff --git a/qtorch/quant/quant_hip/bit_helper.hip b/qtorch/quant/quant_hip/bit_helper.hip new file mode 100644 index 0000000..9fed4e9 --- /dev/null +++ b/qtorch/quant/quant_hip/bit_helper.hip @@ -0,0 +1,59 @@ +#define FLOAT_TO_BITS(x) (*reinterpret_cast(x)) +#define BITS_TO_FLOAT(x) (*reinterpret_cast(x)) + +__device__ __inline__ unsigned int extract_exponent(float *a) { + unsigned int temp = *(reinterpret_cast(a)); + temp = (temp << 1 >> 24); // single preciision, 1 sign bit, 23 mantissa bits + return temp-127+1; // exponent offset and virtual bit +} + +__device__ __inline__ unsigned int round_bitwise_stochastic(unsigned int target, + unsigned int rand_prob, + int man_bits) { + unsigned int mask = (1 << (23-man_bits)) - 1; + unsigned int add_r = target+(rand_prob & mask); + unsigned int quantized = add_r & ~mask; + return quantized; +} + +__device__ __inline__ unsigned int round_bitwise_nearest(unsigned int target, + int man_bits) { + unsigned int mask = (1 << (23-man_bits)) - 1; + unsigned int rand_prob = 1 << (23-man_bits-1); + unsigned int add_r = target+rand_prob; + unsigned int quantized = add_r & ~mask; + return quantized; +} + +__device__ __inline__ unsigned int clip_exponent(int exp_bits, int man_bits, + unsigned int old_num, + unsigned int quantized_num) { + if (quantized_num == 0) + return quantized_num; + + int quantized_exponent_store = quantized_num << 1 >> 1 >> 23; // 1 sign bit, 23 mantissa bits + int max_exponent_store = (1 << (exp_bits - 1)) + 127; // we are not reserving an exponent bit for infinity, nan, etc + // Clippping Value Up + if (quantized_exponent_store > max_exponent_store) + { + unsigned int max_man = (unsigned int)-1 << 9 >> 9 >> (23 - man_bits) << (23 - man_bits); // 1 sign bit, 8 exponent bits, 1 virtual bit + unsigned int max_num = ((unsigned int)max_exponent_store << 23) | max_man; + unsigned int old_sign = old_num >> 31 << 31; + quantized_num = old_sign | max_num; + } + return quantized_num; +} + + +__device__ __inline__ unsigned int clip_max_exponent(int man_bits, + unsigned int max_exponent, + unsigned int quantized_num) { + unsigned int quantized_exponent = quantized_num << 1 >> 24 << 23; // 1 sign bit, 23 mantissa bits + if (quantized_exponent > max_exponent) { + unsigned int max_man = (unsigned int ) -1 << 9 >> 9 >> (23-man_bits) << (23-man_bits); // 1 sign bit, 8 exponent bits + unsigned int max_num = max_exponent | max_man; + unsigned int old_sign = quantized_num >> 31 << 31; + quantized_num = old_sign | max_num; + } + return quantized_num; +} diff --git a/qtorch/quant/quant_hip/block_kernel.hip b/qtorch/quant/quant_hip/block_kernel.hip new file mode 100644 index 0000000..979da5f --- /dev/null +++ b/qtorch/quant/quant_hip/block_kernel.hip @@ -0,0 +1,83 @@ +#include "quant_kernel.h" +#include "sim_helper.hip" +#include "bit_helper.hip" + +// quantize a float into a floating point with [exp_bits] exponent and +// [man_bits] mantissa +__global__ void block_kernel_stochastic(float* __restrict__ a, + int* __restrict__ r, + float* o, int size, + float* __restrict__ max_entry, + int man_bits) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + unsigned int max_entry_bits = FLOAT_TO_BITS(&max_entry[index]); + unsigned int max_exp = max_entry_bits << 1 >> 24 << 23; + float base_float = 6*BITS_TO_FLOAT(&max_exp); + + float target_rebase = a[index]+base_float; + unsigned int target_bits = FLOAT_TO_BITS(&target_rebase); + unsigned int rand_prob = (unsigned int) r[index]; + unsigned int quantized = round_bitwise_stochastic(target_bits, rand_prob, man_bits); + float quantize_float = BITS_TO_FLOAT(&quantized)-base_float; + + unsigned int quantize_bits = FLOAT_TO_BITS(&quantize_float) ; + unsigned int clip_quantize = clip_max_exponent(man_bits-2, max_exp, quantize_bits); + quantize_float = BITS_TO_FLOAT(&clip_quantize); + o[index] = quantize_float; + } +} + +// quantize a float into a floating point with [exp_bits] exponent and +// [man_bits] mantissa +__global__ void block_kernel_nearest(float* __restrict__ a, + float* o, int size, + float* __restrict__ max_entry, + int man_bits) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + unsigned int max_entry_bits = FLOAT_TO_BITS(&max_entry[index]); + unsigned int max_exp = max_entry_bits << 1 >> 24 << 23; + float base_float = 6*BITS_TO_FLOAT(&max_exp); + + float target_rebase = a[index]+base_float; + unsigned int target_bits = FLOAT_TO_BITS(&target_rebase); + unsigned int quantized = round_bitwise_nearest(target_bits, man_bits); + float quantize_float = BITS_TO_FLOAT(&quantized)-base_float; + + unsigned int quantize_bits = FLOAT_TO_BITS(&quantize_float); + unsigned int clip_quantize = clip_max_exponent(man_bits-2, max_exp, quantize_bits); // sign bit, virtual bit + quantize_float = BITS_TO_FLOAT(&clip_quantize); + + o[index] = quantize_float; + } +} + +// quantize a float into a floating point with [exp_bits] exponent and +// [man_bits] mantissa +__global__ void block_kernel_sim_stochastic(float* __restrict__ a, + float* __restrict__ r, + float* o, int size, + float* max_entry, + int wl) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + int exponent = ((int) extract_exponent(max_entry)); + int sigma = exponent-(wl-1); + o[index] = round(a[index], r[index], sigma); + } +} + +// quantize a float into a floating point with [exp_bits] exponent and +// [man_bits] mantissa +__global__ void block_kernel_sim_nearest(float* __restrict__ a, + float* o, int size, + float* max_entry, + int wl) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + int exponent = ((int) extract_exponent(max_entry)); + int sigma = exponent-(wl-1); + o[index] = nearest_round(a[index], sigma); + } +} diff --git a/qtorch/quant/quant_hip/fixed_point_kernel.hip b/qtorch/quant/quant_hip/fixed_point_kernel.hip new file mode 100644 index 0000000..a37a301 --- /dev/null +++ b/qtorch/quant/quant_hip/fixed_point_kernel.hip @@ -0,0 +1,77 @@ +#include "quant_kernel.h" +#include "sim_helper.hip" + + +template +__device__ __inline__ T clamp_helper(T a, T min, T max) { + if (a > max) return max; + else if (a < min) return min; + else return a; +} + +template +__device__ __inline__ T clamp_mask_helper(T a, T min, T max, uint8_t* mask) { + if (a > max) { + *mask = 1; + return max; + } else if (a < min) { + *mask = 1; + return min; + } + *mask = 0; + return a; +} + +// quantize an array of real numbers into fixed point with word length [wl] and [fl] fractional bits +// 2**-[sigma] is the smallest unit of the fixed point representation. Stochastic Rounding with r. +__global__ void fixed_point_quantize_kernel_stochastic(float* __restrict__ a, + float* __restrict__ r, + float* o, int size, + int sigma, bool use_clamp, + float t_min, float t_max) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + o[index] = round(a[index], r[index], sigma); + if (use_clamp) { + o[index] = clamp_helper(o[index], t_min, t_max); + } + } +} + +// quantize an array of real numbers into fixed point with word length [wl] and [fl] fractional bits +// 2**-[sigma] is the smallest unit of the fixed point representation. Nearest Neighbor Rounding. +__global__ void fixed_point_quantize_kernel_nearest(float* __restrict__ a, + float* o, int size, + int sigma, bool use_clamp, + float t_min, float t_max) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + o[index] = nearest_round(a[index], sigma); + if (use_clamp) { + o[index] = clamp_helper(o[index], t_min, t_max); + } + } +} + +__global__ void fixed_point_quantize_kernel_mask_stochastic(float* __restrict__ a, + float* __restrict__ r, + float* o, uint8_t* m, + int size, int sigma, + float t_min, float t_max) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + o[index] = round(a[index], r[index], sigma); + o[index] = clamp_mask_helper(o[index], t_min, t_max, m+index); + } +} + +__global__ void fixed_point_quantize_kernel_mask_nearest(float* __restrict__ a, + float* o, uint8_t* m, + int size, int sigma, + float t_min, float t_max) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + o[index] = nearest_round(a[index], sigma); + o[index] = clamp_mask_helper(o[index], t_min, t_max, m+index); + } +} \ No newline at end of file diff --git a/qtorch/quant/quant_hip/float_kernel.hip b/qtorch/quant/quant_hip/float_kernel.hip new file mode 100644 index 0000000..07ab146 --- /dev/null +++ b/qtorch/quant/quant_hip/float_kernel.hip @@ -0,0 +1,70 @@ +#include "quant_kernel.h" +#include "bit_helper.hip" + +// quantize a float into a floating point with [exp_bits] exponent and +// [man_bits] mantissa +__global__ void float_kernel_stochastic(float* __restrict__ a, + int* __restrict__ r, + float* o, int size, + int man_bits, + int exp_bits) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + unsigned int rand_prob = (unsigned int) r[index]; + unsigned int target,quantize_bits; + target = FLOAT_TO_BITS(&a[index]); + float quantized; + + int target_exp = (target << 1 >> 1 >> 23) -127; + int min_exp = -((1 << (exp_bits - 1)) - 2); + bool subnormal = (target_exp < min_exp); + if (subnormal){ + float shift_float,val; + int shift_bits = ((127+min_exp)<<23) | (target >> 31 <<31); + shift_float = BITS_TO_FLOAT(&shift_bits); + val=a[index]+shift_float; + target = FLOAT_TO_BITS(&val); + quantize_bits = round_bitwise_stochastic(target, rand_prob, man_bits); + quantized = BITS_TO_FLOAT(&quantize_bits) - shift_float; + } + else{ + quantize_bits = round_bitwise_stochastic(target, rand_prob, man_bits); + quantize_bits = clip_exponent(exp_bits, man_bits, target, quantize_bits); + quantized = BITS_TO_FLOAT(&quantize_bits); + } + o[index] = quantized; + } +} + +// quantize a float into a floating point with [exp_bits] exponent and +// [man_bits] mantissa +__global__ void float_kernel_nearest(float* __restrict__ a, + float* o, int size, + int man_bits, + int exp_bits) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + unsigned int target,quantize_bits; + target = FLOAT_TO_BITS(&a[index]); + float quantized; + + int target_exp = (target << 1 >> 1 >> 23) -127; + int min_exp = -((1 << (exp_bits - 1)) - 2); + bool subnormal = (target_exp < min_exp); + if (subnormal){ + float shift_float,val; + int shift_bits = ((127+min_exp)<<23) | (target >> 31 <<31); + shift_float = BITS_TO_FLOAT(&shift_bits); + val=a[index]+shift_float; + target = FLOAT_TO_BITS(&val); + quantize_bits = round_bitwise_nearest(target, man_bits); + quantized = BITS_TO_FLOAT(&quantize_bits) - shift_float; + } + else{ + quantize_bits = round_bitwise_nearest(target, man_bits); + quantize_bits = clip_exponent(exp_bits, man_bits, target, quantize_bits); + quantized = BITS_TO_FLOAT(&quantize_bits); + } + o[index] = quantized; + } +} diff --git a/qtorch/quant/quant_hip/quant.hip b/qtorch/quant/quant_hip/quant.hip new file mode 100644 index 0000000..aed2a60 --- /dev/null +++ b/qtorch/quant/quant_hip/quant.hip @@ -0,0 +1,232 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "quant_cuda.h" +#include "quant_kernel.h" + +using namespace at; + +Tensor get_max_entry(Tensor a, int dim) { + Tensor max_entry; + if (dim == -1) { + max_entry = at::max(at::abs(a)).expand_as(a).contiguous(); + } else if (dim == 0) { + Tensor input_view = a.view({a.size(0), -1}); + max_entry = std::get<0>(input_view.abs().max(1, true)).expand_as(input_view).view_as(a).contiguous(); + } else { + Tensor input_transpose = a.transpose(0, dim); + Tensor input_view = input_transpose.contiguous().view({input_transpose.size(0), -1}); + Tensor max_transpose = std::get<0>(input_view.abs().max(1, true)).expand_as(input_view).view_as(input_transpose); + max_entry = max_transpose.transpose(dim, 0).contiguous(); + } + return max_entry; +} + +Tensor block_quantize_stochastic_cuda(Tensor a, int wl, int dim) { + cudaSetDevice(a.get_device()); + auto o = at::zeros_like(a); + auto rand_ints = randint_like(a, INT_MAX, device(kCUDA).dtype(kInt)); + int64_t size = a.numel(); + + Tensor max_entry = get_max_entry(a, dim); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + block_kernel_stochastic<<>>(a.data_ptr(), + rand_ints.data_ptr(), + o.data_ptr(), + size, + max_entry.data_ptr(), + wl); + return o; +} + +Tensor block_quantize_nearest_cuda(Tensor a, int wl, int dim) { + cudaSetDevice(a.get_device()); + auto o = at::zeros_like(a); + int64_t size = a.numel(); + + Tensor max_entry = get_max_entry(a, dim); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + block_kernel_nearest<<>>(a.data_ptr(), + o.data_ptr(), + size, + max_entry.data_ptr(), + wl); + return o; +} + +Tensor block_quantize_sim_stochastic_cuda(Tensor a, int wl) { + cudaSetDevice(a.get_device()); + auto o = at::zeros_like(a); + auto rand_probs = rand_like(a); + int64_t size = a.numel(); + + Tensor max_entry = at::max(at::abs(a)); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + block_kernel_sim_stochastic<<>>(a.data_ptr(), + rand_probs.data_ptr(), + o.data_ptr(), + size, + max_entry.data_ptr(), + wl); + return o; +} + +Tensor block_quantize_sim_nearest_cuda(Tensor a, int wl) { + cudaSetDevice(a.get_device()); + auto o = at::zeros_like(a); + auto rand_ints = randint_like(a, INT_MAX, device(kCUDA).dtype(kInt)); + int64_t size = a.numel(); + + Tensor max_entry = at::max(at::abs(a)); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + block_kernel_sim_nearest<<>>(a.data_ptr(), + o.data_ptr(), + size, + max_entry.data_ptr(), + wl); + return o; +} + +Tensor float_quantize_stochastic_cuda(Tensor a, int man_bits, int exp_bits) { + // use external random number right now + cudaSetDevice(a.get_device()); + auto o = zeros_like(a); + auto rand_ints = randint_like(a, INT_MAX, device(kCUDA).dtype(kInt)); + int size = a.numel(); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + float_kernel_stochastic<<>>(a.data_ptr(), + rand_ints.data_ptr(), + o.data_ptr(), + size, + man_bits, + exp_bits); + return o; +} + +Tensor float_quantize_nearest_cuda(Tensor a, int man_bits, int exp_bits) { + // use external random number right now + cudaSetDevice(a.get_device()); + auto o = zeros_like(a); + int size = a.numel(); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + float_kernel_nearest<<>>(a.data_ptr(), + o.data_ptr(), + size, + man_bits, + exp_bits); + return o; +} + +void fixed_min_max(int wl, int fl, bool symmetric, float* t_min, float* t_max) { + int sigma = -fl; + *t_min = -ldexp(1.0, wl-fl-1); + *t_max = -*t_min-ldexp(1.0, sigma); + if (symmetric) *t_min = *t_min+ldexp(1.0, sigma); +} + +Tensor fixed_point_quantize_stochastic_cuda(Tensor a, int wl, int fl, bool use_clamp, bool symmetric) { + // use external random number right now + cudaSetDevice(a.get_device()); + auto o = at::zeros_like(a); + auto rand_probs = rand_like(a); + int64_t size = a.numel(); + int sigma = -fl; + float t_min, t_max; + fixed_min_max(wl, fl, symmetric, &t_min, &t_max); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + fixed_point_quantize_kernel_stochastic<<>>(a.data_ptr(), + rand_probs.data_ptr(), + o.data_ptr(), + size, + sigma, + use_clamp, + t_min, + t_max); + return o; +} + +Tensor fixed_point_quantize_nearest_cuda(Tensor a, int wl, int fl, bool use_clamp, bool symmetric) { + // use external random number right now + cudaSetDevice(a.get_device()); + auto o = at::zeros_like(a); + int64_t size = a.numel(); + int sigma = -fl; + float t_min, t_max; + fixed_min_max(wl, fl, symmetric, &t_min, &t_max); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + fixed_point_quantize_kernel_nearest<<>>(a.data_ptr(), + o.data_ptr(), + size, + sigma, + use_clamp, + t_min, + t_max); + return o; +} + +std::tuple fixed_point_quantize_stochastic_mask_cuda(Tensor a, int wl, int fl, bool symmetric) { + // use external random number right now + cudaSetDevice(a.get_device()); + auto o = zeros_like(a); + auto rand_probs = rand_like(a); + auto m = zeros_like(a, a.options().dtype(kByte)); + int64_t size = a.numel(); + int sigma = -fl; + float t_min, t_max; + fixed_min_max(wl, fl, symmetric, &t_min, &t_max); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + fixed_point_quantize_kernel_mask_stochastic<<>>(a.data_ptr(), + rand_probs.data_ptr(), + o.data_ptr(), + m.data_ptr(), + size, + sigma, + t_min, + t_max); + return std::make_tuple(o, m); +} + +std::tuple fixed_point_quantize_nearest_mask_cuda(Tensor a, int wl, int fl, bool symmetric) { + // use external random number right now + cudaSetDevice(a.get_device()); + auto o = at::zeros_like(a); + auto m = zeros_like(a, a.options().dtype(kByte)); + int64_t size = a.numel(); + int sigma = -fl; + float t_min, t_max; + fixed_min_max(wl, fl, symmetric, &t_min, &t_max); + int blockSize = 1024; + int blockNums = (size + blockSize - 1) / blockSize; + + fixed_point_quantize_kernel_mask_nearest<<>>(a.data_ptr(), + o.data_ptr(), + m.data_ptr(), + size, + sigma, + t_min, + t_max); + return std::make_tuple(o, m); +} \ No newline at end of file diff --git a/qtorch/quant/quant_hip/quant_cuda.h b/qtorch/quant/quant_hip/quant_cuda.h new file mode 100644 index 0000000..ff6c49a --- /dev/null +++ b/qtorch/quant/quant_hip/quant_cuda.h @@ -0,0 +1,80 @@ +#include +#include + +using namespace at; + +/** + * quantize a FloatTensor into fixed point number with word length [wl] + * and fractional bits [fl], with option of clamping the over/underflow numbers + * having a symmeric number range. + * Stochastic Rounding. + **/ +Tensor fixed_point_quantize_stochastic_cuda(Tensor a, int wl, int fl, bool use_clamp, bool symmetric); + +/** + * quantize a FloatTensor into fixed point number with word length [wl] + * and fractional bits [fl], with option of clamping the over/underflow numbers + * having a symmeric number range. + * Nearest Rounding. + **/ +Tensor fixed_point_quantize_nearest_cuda(Tensor a, int wl, int fl, bool use_clamp, bool symmetric); + +/** + * quantize a FloatTensor into fixed point number with word length [wl] + * and fractional bits [fl], clamp the over/underflow number and recording the clamping into a mask, + * with the option of having a symmetric number range + * Stochastic Rounding. + **/ +std::tuple fixed_point_quantize_stochastic_mask_cuda(Tensor a, int wl, int fl, bool symmetric); + +/** + * quantize a FloatTensor into fixed point number with word length [wl] + * and fractional bits [fl], clamp the over/underflow number and recording the clamping into a mask, + * with the option of having a symmetric number range + * Nearest Rounding. + **/ +std::tuple fixed_point_quantize_nearest_mask_cuda(Tensor a, int wl, int fl, bool symmetric); + +/** + * quantize a FloatTensor into fixed point number with word length [wl] + * and fractional bits [fl] + * Stochastic Rounding. + **/ +Tensor block_quantize_stochastic_cuda(Tensor a, int wl, int dim); + +/** + * quantize a FloatTensor into fixed point number with word length [wl] + * and fractional bits [fl] + * Nearest Rounding. + **/ +Tensor block_quantize_nearest_cuda(Tensor a, int wl, int dim); + +/** + * quantize a FloatTensor into fixed point number with word length [wl] + * and fractional bits [fl] + * Stochastic Rounding. + **/ +Tensor block_quantize_sim_stochastic_cuda(Tensor a, int wl); + +/** + * quantize a FloatTensor into fixed point number with word length [wl] + * and fractional bits [fl] + * Nearest Rounding. + **/ +Tensor block_quantize_sim_nearest_cuda(Tensor a, int wl); + +/** + * quantize a FloatTensor into a low bit-width floating point Tensor + * with [man_bits] mantissa bits and [exp_bits] exponent bits. + * Does not handle NaN, Inf, and denormal. + * Stochastic Rounding. + **/ +Tensor float_quantize_stochastic_cuda(Tensor a, int man_bits, int exp_bits); + +/** + * quantize a FloatTensor into a low bit-width floating point Tensor + * with [man_bits] mantissa bits and [exp_bits] exponent bits. + * Does not handle NaN, Inf, and denormal. + * Nearest Rounding. + **/ +Tensor float_quantize_nearest_cuda(Tensor a, int man_bits, int exp_bits); diff --git a/qtorch/quant/quant_hip/quant_hip.cpp b/qtorch/quant/quant_hip/quant_hip.cpp new file mode 100644 index 0000000..b650b14 --- /dev/null +++ b/qtorch/quant/quant_hip/quant_hip.cpp @@ -0,0 +1,87 @@ +#include +#include "quant_cuda.h" +#include + +using namespace at; + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +Tensor fixed_point_quantize_nearest(Tensor a, int wl, int fl, bool use_clamp, bool symmetric) +{ + CHECK_INPUT(a); + return fixed_point_quantize_nearest_cuda(a, wl, fl, use_clamp, symmetric); +} + +std::tuple fixed_point_quantize_nearest_mask(Tensor a, int wl, int fl, + bool symmetric) +{ + CHECK_INPUT(a); + return fixed_point_quantize_nearest_mask_cuda(a, wl, fl, symmetric); +} + +Tensor block_quantize_nearest(Tensor a, int wl, int dim) +{ + CHECK_INPUT(a); + return block_quantize_nearest_cuda(a, wl, dim); +} + +Tensor block_quantize_sim_nearest(Tensor a, int wl) +{ + CHECK_INPUT(a); + return block_quantize_sim_nearest_cuda(a, wl); +} + +Tensor float_quantize_nearest(Tensor a, int man_bits, int exp_bits) +{ + CHECK_INPUT(a); + return float_quantize_nearest_cuda(a, man_bits, exp_bits); +} + +Tensor fixed_point_quantize_stochastic(Tensor a, int wl, int fl, bool use_clamp, bool symmetric) +{ + CHECK_INPUT(a); + return fixed_point_quantize_stochastic_cuda(a, wl, fl, use_clamp, symmetric); +} + +std::tuple fixed_point_quantize_stochastic_mask(Tensor a, int wl, int fl, + bool symmetric) +{ + CHECK_INPUT(a); + return fixed_point_quantize_stochastic_mask_cuda(a, wl, fl, symmetric); +} + +Tensor block_quantize_stochastic(Tensor a, int wl, int dim) +{ + CHECK_INPUT(a); + return block_quantize_stochastic_cuda(a, wl, dim); +} + +Tensor block_quantize_sim_stochastic(Tensor a, int wl) +{ + CHECK_INPUT(a); + return block_quantize_sim_stochastic_cuda(a, wl); +} + +Tensor float_quantize_stochastic(Tensor a, int man_bits, int exp_bits) +{ + CHECK_INPUT(a); + return float_quantize_stochastic_cuda(a, man_bits, exp_bits); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("fixed_point_quantize_stochastic", &fixed_point_quantize_stochastic, "Fixed Point Number Stochastic Quantization (CUDA)"); + m.def("fixed_point_quantize_stochastic_mask", &fixed_point_quantize_stochastic_mask, "Fixed Point Number Stochastic Quantization (CUDA)"); + m.def("block_quantize_stochastic", &block_quantize_stochastic, "Block Floating Point Number Stochastic Quantization (CUDA)"); + m.def("block_quantize_sim_stochastic", &block_quantize_sim_stochastic, "Block Floating Point Number Stochastic Quantization (CUDA)"); + m.def("float_quantize_stochastic", &float_quantize_stochastic, "Low-Bitwidth Floating Point Number Stochastic Quantization (CUDA)"); + m.def("fixed_point_quantize_nearest", &fixed_point_quantize_nearest, "Fixed Point Number Nearest Neighbor Quantization (CUDA)"); + m.def("fixed_point_quantize_nearest_mask", &fixed_point_quantize_nearest_mask, "Fixed Point Number Nearest Neighbor Quantization (CUDA)"); + m.def("block_quantize_nearest", &block_quantize_nearest, "Block Floating Point Number Nearest Neighbor Quantization (CUDA)"); + m.def("block_quantize_sim_nearest", &block_quantize_sim_nearest, "Block Floating Point Number Stochastic Quantization (CUDA)"); + m.def("float_quantize_nearest", &float_quantize_nearest, "Low-Bitwidth Floating Point Number Nearest Neighbor Quantization (CUDA)"); +} diff --git a/qtorch/quant/quant_hip/quant_kernel.h b/qtorch/quant/quant_hip/quant_kernel.h new file mode 100644 index 0000000..9235bd8 --- /dev/null +++ b/qtorch/quant/quant_hip/quant_kernel.h @@ -0,0 +1,54 @@ +#include + +__global__ void fixed_point_quantize_kernel_stochastic(float *__restrict__ a, + float *__restrict__ r, + float *o, int size, + int sigma, bool clamp, + float t_min, float t_max); + +__global__ void fixed_point_quantize_kernel_nearest(float *__restrict__ a, + float *o, int size, + int sigma, bool clamp, + float t_min, float t_max); + +__global__ void fixed_point_quantize_kernel_mask_stochastic(float *__restrict__ a, + float *__restrict__ r, + float *o, uint8_t *mask, + int size, int sigma, + float t_min, float t_max); + +__global__ void fixed_point_quantize_kernel_mask_nearest(float *__restrict__ a, + float *o, uint8_t *mask, + int size, int sigma, + float t_min, float t_max); + +__global__ void float_kernel_stochastic(float *__restrict__ a, + int *__restrict__ r, + float *o, int size, + int man_bits, int exp_bits); + +__global__ void float_kernel_nearest(float *__restrict__ a, + float *o, int size, + int man_bits, int exp_bits); + +__global__ void block_kernel_stochastic(float *__restrict__ a, + int *__restrict__ r, + float *o, int size, + float *max_entry, + int man_bits); + +__global__ void block_kernel_nearest(float *__restrict__ a, + float *o, int size, + float *max_entry, + int man_bits); + +__global__ void block_kernel_sim_stochastic(float *__restrict__ a, + float *__restrict__ r, + float *o, int size, + float *max_entry, + int wl); + +__global__ void block_kernel_sim_nearest(float *__restrict__ a, + float *o, int size, + float *max_entry, + int wl); diff --git a/qtorch/quant/quant_hip/sim_helper.hip b/qtorch/quant/quant_hip/sim_helper.hip new file mode 100644 index 0000000..b6479bf --- /dev/null +++ b/qtorch/quant/quant_hip/sim_helper.hip @@ -0,0 +1,24 @@ +#include "quant_kernel.h" +#include + +__device__ __inline__ float round_helper(float a, float r) { + // return floor(a+r); + return nearbyint(a+r-0.5); +} + +__device__ __inline__ float round(float a, float r, int sigma) { + a = ldexp(a, -sigma); + a = round_helper(a, r); + a = ldexp(a, sigma); + return a; +} + +__device__ __inline__ float nearest_round(float a, int sigma) { + a = ldexp(a, -sigma); + // a = nearbyint(a); + a = round(a); + // a = floor(a+0.5); + //a = ceil(a-0.5); + a = ldexp(a, sigma); + return a; +} diff --git a/run_all_tests.sh b/run_all_tests.sh new file mode 100755 index 0000000..472da10 --- /dev/null +++ b/run_all_tests.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Script to run all QPyTorch tests with ROCm support +# Tests both CPU and GPU (CUDA/ROCm) functionality + +set -e # Exit on error + +# Colors for output +GREEN='\033[0;32m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Setup log file +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +LOGFILE="test_results_${TIMESTAMP}.log" + +# Function to log to both console and file +log() { + echo -e "$@" | tee -a "$LOGFILE" +} + +log "==========================================" +log "QPyTorch Test Suite (ROCm Fork)" +log "==========================================" +log "Log file: $LOGFILE" +log "" + +# Check if PyTorch detects GPU +if python -c "import torch; print(torch.cuda.is_available())" | grep -q "True"; then + log -e "${GREEN}✓${NC} GPU detected ($(python -c "import torch; print(torch.cuda.get_device_name(0))"))" + GPU_AVAILABLE=true +else + log -e "${YELLOW}!${NC} No GPU detected - will run CPU tests only" + GPU_AVAILABLE=false +fi +log "" + +# Array to track test results +declare -a PASSED_TESTS +declare -a FAILED_TESTS + +# Function to run a single test +run_test() { + local test_file=$1 + local test_name=$(basename "$test_file" .py) + + log "----------------------------------------" + log "Running: $test_name" + log "----------------------------------------" + + if python -m pytest "$test_file" -v 2>&1 | tee -a "$LOGFILE"; then + log -e "${GREEN}✓ PASSED${NC}: $test_name" + PASSED_TESTS+=("$test_name") + else + log -e "${RED}✗ FAILED${NC}: $test_name" + FAILED_TESTS+=("$test_name") + return 1 + fi +} + +# Run all tests in the test directory +cd "$(dirname "$0")" + +log "Starting test execution..." +log "" + +for test_file in test/test_*.py; do + if [ -f "$test_file" ]; then + run_test "$test_file" || true # Continue even if test fails + log "" + fi +done + +# Summary +log "==========================================" +log "Test Summary" +log "==========================================" +log "" + +if [ ${#PASSED_TESTS[@]} -gt 0 ]; then + log -e "${GREEN}Passed Tests (${#PASSED_TESTS[@]}):${NC}" + for test in "${PASSED_TESTS[@]}"; do + log " ✓ $test" + done + log "" +fi + +if [ ${#FAILED_TESTS[@]} -gt 0 ]; then + log -e "${RED}Failed Tests (${#FAILED_TESTS[@]}):${NC}" + for test in "${FAILED_TESTS[@]}"; do + log " ✗ $test" + done + log "" + log "Full log saved to: $LOGFILE" + exit 1 +else + log -e "${GREEN}All tests passed!${NC}" + log "Full log saved to: $LOGFILE" + exit 0 +fi