From 9aa8deb79d06d529daed94d34cda840840fc9c76 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 11 Feb 2026 23:44:27 +0000 Subject: [PATCH 01/11] fix build errors --- .../cuda/bert/group_query_attention_impl.cu | 12 ++++++++++-- onnxruntime/core/providers/cuda/llm/attention.cc | 5 +++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index e0fa993db29bd..59e2be5e8cd4b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -612,8 +612,8 @@ Status ExtremeDecoding( batch_size, parameters.seqlen_present_kv_cache, // max_seqlen (capacity) data.past_seq_lens, - data.cos_cache, - data.sin_cache, + reinterpret_cast(data.cos_cache), + reinterpret_cast(data.sin_cache), parameters.do_rotary ? parameters.rotary_dim : 0, data.position_ids, parameters.rotary_interleaved, @@ -1105,6 +1105,7 @@ Status QkvToContext( template struct GroupQueryAttentionData; template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>; +template struct GroupQueryAttentionData; template struct GroupQueryAttentionData; template Status QkvToContext( @@ -1121,6 +1122,13 @@ template Status QkvToContext<__nv_bfloat16, __nv_bfloat16>( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>& data); +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index fb0ca5248191f..6c235f95aabcf 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -196,7 +196,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { gqa_parameters.num_splits = 1; // Construct GroupQueryAttentionData - onnxruntime::contrib::cuda::GroupQueryAttentionData gqa_data; + onnxruntime::contrib::cuda::GroupQueryAttentionData gqa_data; // Scratch buffers for flash/memory efficient attention IAllocatorUniquePtr k_buffer; @@ -355,6 +355,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Centralized scratch buffer allocation using GQABufferRequirements auto buffer_req = onnxruntime::contrib::cuda::GQABufferRequirements::Compute( gqa_parameters, + false, // use_xqa gqa_data.use_flash_attention, gqa_data.use_flash_attention_fast_decode, gqa_data.use_memory_efficient_attention); @@ -478,7 +479,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Call GQA kernel (with flash or memory efficient attention) cublasHandle_t cublas = GetCublasHandle(context); - return onnxruntime::contrib::cuda::QkvToContext( + return onnxruntime::contrib::cuda::QkvToContext( device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data); } From 019a2b157314fbf5d71e148023f3750d25e96e4b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 12 Feb 2026 00:42:35 +0000 Subject: [PATCH 02/11] Support fp8 kv cache --- cmake/CMakeLists.txt | 15 ++ .../cuda/bert/group_query_attention.cc | 26 +- .../cuda/bert/group_query_attention_impl.cu | 166 +++++++------ .../cuda/bert/group_query_attention_qdq.cuh | 62 ++++- 
.../cuda/bert/group_query_attention_qkv.cuh | 92 ++++--- onnxruntime/contrib_ops/cuda/bert/xqa/mha.h | 2 +- .../cuda/bert/xqa/xqa_loader_bf16_fp8_128.cu | 9 + .../cuda/bert/xqa/xqa_loader_bf16_fp8_256.cu | 9 + .../cuda/bert/xqa/xqa_loader_bf16_fp8_64.cu | 9 + .../bert/xqa/xqa_loader_bf16_fp8_impl.cuh | 120 +++++++++ .../cuda/bert/xqa/xqa_loader_bf16_impl.cuh | 32 +++ .../cuda/bert/xqa/xqa_loader_fp16_fp8_128.cu | 9 + .../cuda/bert/xqa/xqa_loader_fp16_fp8_256.cu | 9 + .../cuda/bert/xqa/xqa_loader_fp16_fp8_64.cu | 9 + .../bert/xqa/xqa_loader_fp16_fp8_impl.cuh | 119 +++++++++ .../cuda/bert/xqa/xqa_loader_fp16_impl.cuh | 34 +++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 8 + .../core/graph/contrib_ops/bert_defs.cc | 10 +- .../tools/transformers/io_binding_helper.py | 52 ++++ .../test/python/transformers/benchmark_gqa.py | 32 ++- .../python/transformers/gqa_test_helper.py | 41 ++- .../test/python/transformers/test_gqa.py | 234 ++++++++++++++++-- 22 files changed, 949 insertions(+), 150 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_128.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_256.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_64.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh create mode 100644 onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_128.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_256.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_64.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 11244e46b78a0..6cfee9e495451 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -104,6 +104,7 @@ option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled do cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF) option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF) +option(onnxruntime_USE_FP8_KV_CACHE "Build cuda kernels for fp8 kv cache" ON) option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) @@ -783,6 +784,11 @@ if (onnxruntime_USE_CUDA) message( STATUS "Enable int4 kv cache for CUDA EP") list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1) endif() + + if (onnxruntime_USE_FP8_KV_CACHE) + message( STATUS "Enable fp8 kv cache for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_FP8_KV_CACHE=1) + endif() endif() if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA)) @@ -1442,6 +1448,15 @@ if (Git_FOUND) if (onnxruntime_USE_INT4_KV_CACHE) string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ") endif() + if (onnxruntime_USE_FP8_KV_CACHE) + string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ") + endif() + if (onnxruntime_DUMP_TENSOR) + string(APPEND ORT_BUILD_INFO "dump-tensor=1, ") + endif() + if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS) + string(APPEND ORT_BUILD_INFO "dump-node=1, ") + endif() endif() string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}") configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h) diff 
--git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 39154ca395fc1..a965e00f6a391 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -63,6 +63,10 @@ REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16) REGISTER_KERNEL_TYPED(BFloat16, BFloat16) REGISTER_KERNEL_TYPED(MLFloat16, int8_t) REGISTER_KERNEL_TYPED(BFloat16, int8_t) +#ifdef USE_FP8_KV_CACHE +REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN) +REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN) +#endif #ifdef USE_INT4_KV_CACHE REGISTER_KERNEL_TYPED(MLFloat16, uint8_t) REGISTER_KERNEL_TYPED(BFloat16, uint8_t) @@ -292,6 +296,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.past_present_share_buffer = (data.past_key == data.present_key); bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE); + constexpr bool is_int8 = std::is_same::value; + constexpr bool is_fp8 = std::is_same::value; // Allocate XQA scratch if needed (only for Flash Decoding path) IAllocatorUniquePtr xqa_scratch_buffer; @@ -315,18 +321,30 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.local_window_size == -1) { int group_size = parameters.num_heads / parameters.kv_num_heads; - bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR && + bool is_int8_quantized_supported = is_int8 && + (k_quant_type_ == KVQuantizationType::PER_TENSOR && v_quant_type_ == KVQuantizationType::PER_TENSOR && data.k_scale == data.v_scale && // XQA requires k_scale and v_scale to be the same. Here requires k_scale and v_scale are same tensor. - parameters.kv_cache_bit_width == 8 && (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32)); +#ifdef USE_FP8_KV_CACHE + bool is_fp8_quantized_supported = is_fp8 && + (k_quant_type_ == KVQuantizationType::PER_TENSOR && + v_quant_type_ == KVQuantizationType::PER_TENSOR && + data.k_scale == data.v_scale && + (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && + (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32) && + (device_prop.major >= 9 || (device_prop.major == 8 && device_prop.minor == 9))); // FP8 requires SM89+ (Ada Lovelace) +#else + constexpr bool is_fp8_quantized_supported = false; +#endif + bool is_non_quantized_supported = !is_inputs_quantized && (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && (64 % group_size == 0); - data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported); + data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported); if (data.use_xqa) { size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize( @@ -336,7 +354,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.kv_num_heads, parameters.head_size, parameters.seqlen_present_kv_cache, - parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone, + parameters.k_quant_type != KVQuantizationType::NONE ? (is_fp8 ? 
XqaQuantType::kFp8 : XqaQuantType::kInt8) : XqaQuantType::kNone, std::is_same::value); assert(xqa_internal_bytes > 0); // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 59e2be5e8cd4b..dc0fc1b32b2bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -98,10 +98,9 @@ Status PrepareQKV( q_out = nullptr; } - CudaT* k = reinterpret_cast(data.present_key); - CudaT* v = reinterpret_cast(data.present_value); + CudaU* k = reinterpret_cast(data.present_key); + CudaU* v = reinterpret_cast(data.present_value); int max_cache_length = parameters.seqlen_present_kv_cache; - bool is_cache_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); if (!parameters.past_present_share_buffer) { size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(CudaU); @@ -109,32 +108,22 @@ Status PrepareQKV( CUDA_CALL_THROW(cudaMemsetAsync(data.present_value, 0, kv_buffer_size, stream)); } + bool is_cache_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + assert(is_cache_bnsh); // Only support BNSH format for now + // Copy past KV to present KV if needed if (!parameters.past_present_share_buffer && data.past_key != nullptr && parameters.seqlen_past_kv_cache > 0) { - if (is_cache_bnsh) { - size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaU); - size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * sizeof(CudaU); - size_t width = src_pitch; - size_t height = (size_t)batch_size * kv_num_heads; - - CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, - cudaMemcpyDeviceToDevice, stream)); - CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, - cudaMemcpyDeviceToDevice, stream)); - } else { - size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaU); - size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaU); - size_t width = src_pitch; - size_t height = (size_t)batch_size; - - CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, - cudaMemcpyDeviceToDevice, stream)); - CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, - cudaMemcpyDeviceToDevice, stream)); - } + size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaU); + size_t dst_pitch = (size_t)max_cache_length * head_size * sizeof(CudaU); + size_t width = src_pitch; + size_t height = (size_t)batch_size * kv_num_heads; + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); } - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), parameters.is_packed_qkv ? 
nullptr : reinterpret_cast(data.key), @@ -144,8 +133,8 @@ Status PrepareQKV( max_cache_length, data.past_seq_lens, reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, - is_cache_bnsh, parameters.k_quant_type, parameters.kv_cache_bit_width, - stream, max_threads_per_block)); + is_cache_bnsh, parameters.k_quant_type, + stream, max_threads_per_block))); if (q_out != nullptr) { q = reinterpret_cast(q_out); @@ -585,6 +574,7 @@ Status ExtremeDecoding( // bool is_bf16 = std::is_same::value; typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; bool past_bsnh = (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); // Ultimate Fused Preprocessing: Unpack, RoPE Q, RoPE K, Quantize K/V, Append K/V @@ -595,14 +585,14 @@ Status ExtremeDecoding( q_input_for_xqa = reinterpret_cast(data.query); } - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), q_rot_ptr, // unpacked_q (can be null if !do_rotary) - data.present_key, - data.present_value, + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), data.k_scale, data.v_scale, num_heads, @@ -619,14 +609,15 @@ Status ExtremeDecoding( parameters.rotary_interleaved, !past_bsnh, // is_cache_bnsh parameters.k_quant_type, - parameters.kv_cache_bit_width, stream, - device_prop.maxThreadsPerBlock)); + device_prop.maxThreadsPerBlock))); // Determine workspace size for XQA void* xqa_workspace = data.xqa_buffer; size_t xqa_workspace_size = data.xqa_buffer_bytes; + constexpr bool is_fp8 = std::is_same::value; + using onnxruntime::contrib::cuda::XqaQuantType; // 5. Launch XQA Status status = onnxruntime::contrib::cuda::LaunchXQAKernel( device_prop, @@ -644,8 +635,8 @@ Status ExtremeDecoding( past_bsnh, data.past_seq_lens, data.k_scale, // kv_cache_scale - // Map KVQuantizationType (0=NONE, 1=TENSOR, 2=CHANNEL) to XqaQuantType (0=FP16/BF16, 1=INT8, 2=FP8) - (parameters.k_quant_type == KVQuantizationType::NONE) ? onnxruntime::contrib::cuda::XqaQuantType::kNone : onnxruntime::contrib::cuda::XqaQuantType::kInt8, + // Map cache type to XqaQuantType: NONE->kNone, Float8E4M3FN->kFp8, int8->kInt8 + (parameters.k_quant_type == KVQuantizationType::NONE) ? XqaQuantType::kNone : (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8), xqa_workspace, xqa_workspace_size); @@ -806,6 +797,7 @@ Status DequantizeFlashAttentionFallback( // We need to dequantize the entire KV cache (present_key/value) into a float/half buffer (data.qkv_buffer). 
// Layout in qkv_buffer: [Q (rotated)] [K_dequantized] [V_dequantized] typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; CudaT* q_rot = reinterpret_cast(data.qkv_buffer); size_t q_elements = static_cast(parameters.batch_size) * parameters.sequence_length * parameters.num_heads * parameters.head_size; size_t k_elements = static_cast(parameters.batch_size) * parameters.seqlen_present_kv_cache * parameters.kv_num_heads * parameters.head_size; @@ -815,48 +807,34 @@ Status DequantizeFlashAttentionFallback( // Step 1: Update Quantized Cache // We can use LaunchUnpackRoPEQuantizeAppend to unpack new QKV, apply RoPE, and append to quantized cache. // This will also put rotated Q into q_rot. - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), - q_rot, data.present_key, data.present_value, data.k_scale, data.v_scale, + q_rot, reinterpret_cast(data.present_key), reinterpret_cast(data.present_value), + data.k_scale, data.v_scale, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, parameters.batch_size, parameters.seqlen_present_kv_cache, data.past_seq_lens, reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH), - parameters.k_quant_type, parameters.kv_cache_bit_width, - stream, device_prop.maxThreadsPerBlock)); + parameters.k_quant_type, + stream, device_prop.maxThreadsPerBlock))); // Step 2: Dequantize Entire Cache // We now have the updated quantized cache in data.present_key/value. We need to dequantize it to k_dequant/v_dequant. 
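// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the per-element math behind the
// LaunchDequantizeKV calls in this step, assuming a per-tensor scale, an FP8 E4M3
// cache, and half output. The real launcher also covers INT8/INT4, per-channel
// scales, and BSNH/BNSH layouts; the kernel name below is hypothetical.
#include <cuda_fp8.h>
#include <cuda_fp16.h>

__global__ void DequantizeFp8PerTensorSketch(half* out, const __nv_fp8_e4m3* cache,
                                             const float* scale, size_t n) {
  size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    // dequantized = fp8_value * scale, with the single per-tensor scale broadcast
    out[i] = __float2half(static_cast<float>(cache[i]) * scale[0]);
  }
}
// ---------------------------------------------------------------------------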
bool is_bsnh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - if (parameters.kv_cache_bit_width == 8) { - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, - nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, - parameters.head_size, 8, parameters.k_quant_type, is_bsnh))); + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, + nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, + parameters.head_size, parameters.kv_cache_bit_width, parameters.k_quant_type, is_bsnh))); - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, - nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, - parameters.head_size, 8, parameters.v_quant_type, is_bsnh))); -#ifdef USE_INT4_KV_CACHE - } else if (parameters.kv_cache_bit_width == 4) { - // Int4 support if needed - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, - nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, - parameters.head_size, 4, parameters.k_quant_type, is_bsnh))); - - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, - nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, - parameters.head_size, 4, parameters.v_quant_type, is_bsnh))); -#endif - } + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, + nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, + parameters.head_size, parameters.kv_cache_bit_width, parameters.v_quant_type, is_bsnh))); // Step 3: Run Flash Attention on dequantized k/v bool is_causal = parameters.is_unidirectional; @@ -913,7 +891,7 @@ Status FlashAttentionAndQuantizeKV( CudaT* k_final = q_final + q_elements; CudaT* v_final = k_final + k_elements; - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), @@ -925,8 +903,7 @@ Status FlashAttentionAndQuantizeKV( parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, false, // BSNH for scratch KVQuantizationType::NONE, - 0, // bit_width is 0 since we are not quantizing here. - stream, max_threads_per_block)); + stream, max_threads_per_block))); // 2. Run Float Flash Attention bool is_causal = parameters.is_unidirectional; @@ -945,13 +922,23 @@ Status FlashAttentionAndQuantizeKV( true, // kv_bsnh = true (BSNH) local_window_size)); - // 3. 
Quantize K and V to present cache if (parameters.k_quant_type != KVQuantizationType::NONE) { if (parameters.kv_cache_bit_width == 8) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 8, parameters.k_quant_type, true, past_bsnh))); +#ifdef USE_FP8_KV_CACHE + if constexpr (std::is_same::value) { + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast<__nv_fp8_e4m3*>(data.present_key), reinterpret_cast(k_final), data.k_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, 8, parameters.k_quant_type, true, past_bsnh))); + } else { +#endif + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, 8, parameters.k_quant_type, true, past_bsnh))); +#ifdef USE_FP8_KV_CACHE + } +#endif #ifdef USE_INT4_KV_CACHE } else if (parameters.kv_cache_bit_width == 4) { ORT_RETURN_IF_ERROR((LaunchQuantizeKV( @@ -964,10 +951,21 @@ Status FlashAttentionAndQuantizeKV( if (parameters.v_quant_type != KVQuantizationType::NONE) { if (parameters.kv_cache_bit_width == 8) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 8, parameters.v_quant_type, true, past_bsnh))); +#ifdef USE_FP8_KV_CACHE + if constexpr (std::is_same::value) { + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast<__nv_fp8_e4m3*>(data.present_value), reinterpret_cast(v_final), data.v_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, 8, parameters.v_quant_type, true, past_bsnh))); + } else { +#endif + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, 8, parameters.v_quant_type, true, past_bsnh))); +#ifdef USE_FP8_KV_CACHE + } +#endif #ifdef USE_INT4_KV_CACHE } else if (parameters.kv_cache_bit_width == 4) { ORT_RETURN_IF_ERROR((LaunchQuantizeKV( @@ -1145,6 +1143,7 @@ template Status QkvToContext<__nv_bfloat16, int8_t>( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData<__nv_bfloat16, int8_t>& data); +#ifdef USE_INT4_KV_CACHE template struct GroupQueryAttentionData; template Status QkvToContext( @@ -1162,6 +1161,27 @@ template Status QkvToContext<__nv_bfloat16, uint8_t>( Stream* ort_stream, contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData<__nv_bfloat16, uint8_t>& data); +#endif + +#ifdef USE_FP8_KV_CACHE +template struct GroupQueryAttentionData; + +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +template struct GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>; + +template Status QkvToContext<__nv_bfloat16, __nv_fp8_e4m3>( + const cudaDeviceProp& device_prop, + 
cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>& data); +#endif template Status LaunchUnpackQKV(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); template Status LaunchUnpackQKV<__nv_bfloat16, LAYOUT_BNSH>(const __nv_bfloat16* packed_qkv, __nv_bfloat16* unpacked_q, __nv_bfloat16* unpacked_k, __nv_bfloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh index 3aa9d6d96789a..a16dc46046951 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh @@ -6,6 +6,7 @@ #define KV_QUANT_SUPPORTED 1 #include +#include #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -25,6 +26,7 @@ constexpr int kInt4Max = 7; constexpr int kInt8Min = -128; constexpr int kInt8Max = 127; constexpr int kInt4ZeroPacked = 0x88; // (0 + 8) | ((0 + 8) << 4) for INT4 zero padding +constexpr float kFp8E4M3Max = 448.0f; // Max value for E4M3 format constexpr int kThreadsPerBlock = 256; template @@ -86,7 +88,12 @@ struct TypeConverter<__nv_bfloat16> { // ------------- // Cache: BNSH (batch, num_heads, sequence_length, head_size) // INT4: (head_size + 1) / 2 bytes per head -// INT8: head_size bytes per head +// INT8/FP8: head_size bytes per head +// +// FP8 E4M3: Native CUDA FP8 format +// - Range: [-448, 448] +// - Storage: __nv_fp8_e4m3 (1 byte) +// - Conversion: Native CUDA cast via __nv_cvt_float_to_fp8/fp8_to_float // ============================================================================ // Dequantization Kernel: Converts Quantized (Int8/Int4) KV cache back to Floating Point (T). @@ -143,7 +150,14 @@ __global__ void DequantizeKernel(T* dequantized_data, (bit_width == 4 ? h / 2 : h); } - if (bit_width == 8) { + // FP8 check must come first since it also has bit_width=8 +#ifdef USE_FP8_KV_CACHE + if constexpr (std::is_same::value) { + __nv_fp8_e4m3 fp8_val = reinterpret_cast(quantized_data)[input_idx]; + quantized_float = static_cast(fp8_val); + } else +#endif + if (bit_width == 8) { quantized_float = static_cast( reinterpret_cast(quantized_data)[input_idx]); #ifdef USE_INT4_KV_CACHE @@ -231,6 +245,21 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, h_idx; } reinterpret_cast(quantized_data)[out_idx] = 0; +#ifdef USE_FP8_KV_CACHE + } else if constexpr (std::is_same::value) { // FP8 + int64_t out_idx = i; + if (is_output_bsnh) { + int64_t b_idx = b; + int64_t n_idx = n; + int64_t s_idx = s; + int64_t h_idx = i % elements_per_head_packed; + out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + + s_idx * num_heads * elements_per_head_packed + + n_idx * elements_per_head_packed + + h_idx; + } + reinterpret_cast<__nv_fp8_e4m3*>(quantized_data)[out_idx] = __nv_fp8_e4m3(0.0f); +#endif #ifdef USE_INT4_KV_CACHE } else if (bit_width == 4) { // INT4 // With packed iteration, each thread handles one byte (2 values). 
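// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the scalar quantize step that the
// FP8 branch added to QuantizeKernel below performs for each element, assuming a
// per-tensor scale. The real kernel additionally resolves per-channel scales and
// BSNH/BNSH output indexing; the helper name here is hypothetical.
#include <cuda_fp8.h>

__device__ __forceinline__ __nv_fp8_e4m3 QuantizeToFp8E4M3Sketch(float value, float scale) {
  constexpr float kFp8E4M3Max = 448.0f;  // largest finite E4M3 magnitude
  float inv_scale = (scale == 0.0f) ? 0.0f : 1.0f / scale;
  // Scale, then clamp into the representable E4M3 range before the narrowing cast.
  float scaled = fmaxf(-kFp8E4M3Max, fminf(kFp8E4M3Max, value * inv_scale));
  return __nv_fp8_e4m3(scaled);
}
// ---------------------------------------------------------------------------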
@@ -271,7 +300,34 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, h_idx; } - if (bit_width == 8) { +#ifdef USE_FP8_KV_CACHE + if constexpr (std::is_same::value) { + int h = h_packed; + float scale_val = 1.0f; + if (quant_type == KVQuantizationType::PER_TENSOR) { + scale_val = static_cast(scale[0]); + } else { // PER_CHANNEL + int scale_idx = n * head_size + h; + scale_val = static_cast(scale[scale_idx]); + } + + float inv_scale = (scale_val == 0.0f) ? 0.0f : 1.0f / scale_val; + int64_t flattened_input_idx = is_input_bsnh ? ((int64_t)b * input_sequence_length * num_heads * head_size + + (int64_t)s * num_heads * head_size + + (int64_t)n * head_size + + h) + : ((int64_t)b * num_heads * input_sequence_length * head_size + + (int64_t)n * input_sequence_length * head_size + + (int64_t)s * head_size + + h); + float val_float = static_cast(dequantized_data[flattened_input_idx]) * inv_scale; + + // Clamp to FP8 E4M3 range and convert + val_float = fmaxf(-kFp8E4M3Max, fminf(kFp8E4M3Max, val_float)); + reinterpret_cast<__nv_fp8_e4m3*>(quantized_data)[output_idx] = __nv_fp8_e4m3(val_float); + } else +#endif + if (bit_width == 8) { int h = h_packed; float scale_val = 1.0f; if (quant_type == KVQuantizationType::PER_TENSOR) { diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index d5c95be316a1f..851cc35018e2b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -3,6 +3,10 @@ #pragma once #include +#include +#ifdef USE_FP8_KV_CACHE +#include +#endif #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -28,18 +32,19 @@ namespace cuda { // 4. Writes the rotated Q back to global memory (unpacked_q) for the subsequent attention kernel. // // Template Parameters: -// - T: The floating point type (half or BFloat16). -// - BIT_WIDTH: The bit width for KV cache quantization (16=none, 8=Int8, 4=Int4). +// - T: The floating point type for query (half or BFloat16). +// - U: The cache element type (T for no quant, int8_t for INT8, uint8_t for INT4, __nv_fp8_e4m3 for FP8). +// - BIT_WIDTH: The bit width for KV cache quantization (16=none, 8=Int8/FP8, 4=Int4). // - MAX_HEAD_SIZE: Maximum supported head size, used for shared memory allocation. -template +template __global__ void UnpackRoPEAppend( const T* packed_qkv, const T* query, const T* key, const T* value, T* unpacked_q, - void* k_cache, - void* v_cache, + U* k_cache, + U* v_cache, const float* k_scale, const float* v_scale, const int num_heads, @@ -200,17 +205,32 @@ __global__ void UnpackRoPEAppend( // No quantization: direct store reinterpret_cast(cache_ptr)[cache_idx / elements_per_thread] = *reinterpret_cast(vals); } else if constexpr (BIT_WIDTH == 8) { - // Int8 Quantization: 1 element per byte + // 8-bit quantization: either INT8 or FP8 E4M3 based on cache type U const float* scale_buffer = (head_type == KEY) ? k_scale : v_scale; uint64_t packed = 0; - for (int i = 0; i < elements_per_thread; ++i) { - float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0]; - float inv_s = (sc == 0.0f) ? 
0.0f : 1.0f / sc; - int8_t q = static_cast(max(-128.0f, min(127.0f, rintf(static_cast(vals[i]) * inv_s)))); - packed |= (static_cast(static_cast(q)) << (i * 8)); +#ifdef USE_FP8_KV_CACHE + if constexpr (std::is_same::value) { + // FP8 E4M3 Quantization: scale and convert to FP8 format + constexpr float kFp8E4M3Max = 448.0f; + for (int i = 0; i < 8; ++i) { + float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0]; + float scaled_val = min(kFp8E4M3Max, max(-kFp8E4M3Max, static_cast(vals[i]) * (sc == 0.0f ? 0.0f : 1.0f / sc))); + __nv_fp8_e4m3 fp8_val = __nv_fp8_e4m3(scaled_val); + packed |= (static_cast(*reinterpret_cast(&fp8_val)) << (i * 8)); + } + } else +#endif + { + // INT8 Quantization: round and clamp to [-128, 127] + for (int i = 0; i < 8; ++i) { + float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0]; + int8_t q = static_cast(max(-128.0f, min(127.0f, rintf(static_cast(vals[i]) * (sc == 0.0f ? 0.0f : 1.0f / sc))))); + packed |= (static_cast(static_cast(q)) << (i * 8)); + } } // Store 8 elements (8 bytes) at once - reinterpret_cast(cache_ptr)[cache_idx / 8] = packed; + unsigned char* cache_byte_ptr = reinterpret_cast((head_type == KEY) ? k_cache : v_cache); + reinterpret_cast(cache_byte_ptr + cache_idx)[0] = packed; } else if constexpr (BIT_WIDTH == 4) { // Int4 Quantization: 2 elements per byte constexpr float kInt4Min = -8.0f; @@ -237,28 +257,28 @@ __global__ void UnpackRoPEAppend( // Internal dispatcher that selects the appropriate template specialization based on head_size. // MAX_HEAD_SIZE is used to optimize shared memory usage and kernel performance. -template +template Status DispatchUnpackRoPEAppendHeadSize( const dim3& grid, const dim3& block, cudaStream_t stream, const T* packed_qkv, const T* query, const T* key, const T* value, - T* unpacked_q, void* k_cache, void* v_cache, + T* unpacked_q, U* k_cache, U* v_cache, const float* k_scale, const float* v_scale, const int num_heads, const int kv_num_heads, const int head_size, const int d, const int max_seqlen, const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel) { if (head_size <= 64) { - UnpackRoPEAppend<<>>( + UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); } else if (head_size <= 128) { - UnpackRoPEAppend<<>>( + UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); } else if (head_size <= 256) { - UnpackRoPEAppend<<>>( + UnpackRoPEAppend<<>>( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); @@ -269,18 +289,21 @@ Status DispatchUnpackRoPEAppendHeadSize( } // Public entry point to launch the Unpack+RoPE+Append kernel. -// Handles parameter validation, grid/block sizing, and bit-width dispatching. -template +// Handles parameter validation, grid/block sizing, and type-based dispatching. 
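// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the compile-time rule that replaces
// the old runtime bit_width argument, mapping the cache element type U to a
// quantization bit width. The helper name is hypothetical; the type/width pairs
// mirror the dispatch in LaunchUnpackRoPEAppend below.
#include <cstdint>
#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

template <typename T, typename U>
constexpr int CacheBitWidthSketch() {
  if constexpr (std::is_same<U, T>::value) {
    return 16;  // no quantization: cache stores the same type as Q/K/V
  } else if constexpr (std::is_same<U, int8_t>::value ||
                       std::is_same<U, __nv_fp8_e4m3>::value) {
    return 8;   // INT8 and FP8 E4M3 both store one byte per element
  } else if constexpr (std::is_same<U, uint8_t>::value) {
    return 4;   // INT4: two elements packed per byte
  } else {
    return 0;   // unsupported cache type
  }
}

static_assert(CacheBitWidthSketch<half, half>() == 16, "unquantized cache keeps the input type");
static_assert(CacheBitWidthSketch<half, __nv_fp8_e4m3>() == 8, "FP8 cache stores one byte per element");
// ---------------------------------------------------------------------------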
+// Template parameters: +// - T: Query/Key/Value floating point type (half or BFloat16) +// - U: Cache element type (T for no quant, int8_t for INT8, uint8_t for INT4, __nv_fp8_e4m3 for FP8) +template Status LaunchUnpackRoPEAppend( const T* packed_qkv, const T* query, const T* key, const T* value, - T* unpacked_q, void* k_cache, void* v_cache, + T* unpacked_q, U* k_cache, U* v_cache, const float* k_scale, const float* v_scale, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, const int max_seqlen, const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const KVQuantizationType k_quant_type, - const int bit_width, cudaStream_t stream, const int max_threads_per_block) { + cudaStream_t stream, const int max_threads_per_block) { constexpr int elements_per_vector = sizeof(float4) / sizeof(T); if (head_size % elements_per_vector != 0) { @@ -315,26 +338,37 @@ Status LaunchUnpackRoPEAppend( bool per_channel = (k_quant_type == KVQuantizationType::PER_CHANNEL); - if (bit_width == 0) { - return DispatchUnpackRoPEAppendHeadSize( + // Dispatch based on cache type U: + // - std::is_same: No quantization (BIT_WIDTH=16) + // - std::is_same or FP8: 8-bit quantization (BIT_WIDTH=8) + // - std::is_same: 4-bit quantization (BIT_WIDTH=4) + if constexpr (std::is_same::value) { + // No quantization: cache type same as input type + return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); - } else if (bit_width == 8) { - return DispatchUnpackRoPEAppendHeadSize( + } else if constexpr (std::is_same::value +#ifdef USE_FP8_KV_CACHE + || std::is_same::value +#endif + ) { + // INT8 or FP8 quantization (both 8-bit, distinguished inside kernel by type check) + return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); #ifdef USE_INT4_KV_CACHE - } else if (bit_width == 4) { - return DispatchUnpackRoPEAppendHeadSize( + } else if constexpr (std::is_same::value) { + // INT4 quantization (packed 2 elements per byte) + return DispatchUnpackRoPEAppendHeadSize( grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); #endif + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported cache type U for GQA quantization."); } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported bit_width (", bit_width, ") for GQA quantization."); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h b/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h index 5aa78aa242306..d803cb6fba531 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h @@ -50,7 +50,7 @@ constexpr uint32_t inputSeqLen = 1; // speculative decoding if > 1 constexpr bool useKVCache = USE_KV_CACHE; using SeqLenDataType = uint32_t; -#endif +#endif // MHA_H_COMMON // 
Dependent definitions #ifndef MHA_H_DEPENDENT diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_128.cu new file mode 100644 index 0000000000000..612f2fd14f09a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_128.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 128 +#define HEAD_DIM_NAMESPACE H128 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_bf16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_256.cu new file mode 100644 index 0000000000000..9329679593e7c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_256.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 256 +#define HEAD_DIM_NAMESPACE H256 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_bf16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_64.cu new file mode 100644 index 0000000000000..d3144b5bb7e2b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_64.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 64 +#define HEAD_DIM_NAMESPACE H64 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_bf16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh new file mode 100644 index 0000000000000..481fcb63c1f8c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_fp8_impl.cuh @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
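// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the macro-stamping pattern used by
// this header. Each thin .cu file fixes HEAD_ELEMS/HEAD_DIM_NAMESPACE, and this
// header then re-includes xqa_impl_gen.cuh once per NAMESPACE_NAME/GRP_SIZE/M_TILESIZE
// setting to emit one kernel namespace per group size. The self-contained example
// below compresses that re-include idiom into a function-like macro; all names in
// it are hypothetical stand-ins.
#define STAMP_GROUP_SKETCH(NS, GROUP_SIZE)           \
  namespace NS {                                     \
  constexpr int kGroupSize = GROUP_SIZE;             \
  }

STAMP_GROUP_SKETCH(sketch_grp4, 4)
STAMP_GROUP_SKETCH(sketch_grp8, 8)
STAMP_GROUP_SKETCH(sketch_grp16, 16)
STAMP_GROUP_SKETCH(sketch_grp32, 32)
#undef STAMP_GROUP_SKETCH

static_assert(sketch_grp8::kGroupSize == 8, "each stamped namespace fixes its own group size");
// ---------------------------------------------------------------------------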
+ +#pragma once +#include "xqa_loader.h" +#include + +// HEAD_ELEMS must be defined by the including file +#ifndef HEAD_ELEMS +#error "HEAD_ELEMS must be defined before including xqa_loader_bf16_fp8_impl.cuh" +#endif + +// HEAD_DIM_NAMESPACE must be defined by the including file +#ifndef HEAD_DIM_NAMESPACE +#error "HEAD_DIM_NAMESPACE must be defined before including xqa_loader_bf16_fp8_impl.cuh" +#endif + +// Define global constants for FP8 E4M3 KV Cache with BF16 Query +#define CACHE_ELEM_ENUM 2 // FP8 E4M3 +#define USE_PAGED_KV_CACHE 0 +#define TOKENS_PER_PAGE 0 +#define INPUT_FP16 0 // Q is BF16 +#define ALLOW_MULTI_BLOCK_MODE 1 + +#pragma nv_diag_suppress 177 +#pragma nv_diag_suppress 20012 + +// Include common headers once +#include "cuda_hint.cuh" +#include "mha.h" +// Include all helpers globally to ensure visibility +#include "ldgsts.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" +#include "hostUtils.h" + +// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen +#undef HEAD_GRP_SIZE +#undef M_TILESIZE + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace HEAD_DIM_NAMESPACE { + +// ============================================================================ +// FP8 E4M3 KV Cache Instantiations for BF16 Query +// ============================================================================ + +#define NAMESPACE_NAME grp4_bf16_fp8 +#define GRP_SIZE 4 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp8_bf16_fp8 +#define GRP_SIZE 8 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp16_bf16_fp8 +#define GRP_SIZE 16 +#define M_TILESIZE 16 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp32_bf16_fp8 +#define GRP_SIZE 32 +#define M_TILESIZE 32 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +Status LaunchXQAFp8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size) { + int group_size = num_heads / kv_num_heads; + switch (group_size) { + case 4: + return grp4_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 8: + return grp8_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 16: + return grp16_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 32: + return grp32_bf16_fp8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, 
max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 only supports group_size 4, 8, 16, 32. Input has ", group_size); + } +} + +} // namespace HEAD_DIM_NAMESPACE +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh index 644dec2c67bbd..c2d9c057c6e50 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh @@ -116,6 +116,28 @@ Status LaunchXQAInt8KernelBF16( void* workspace, size_t workspace_size); +#ifdef USE_FP8_KV_CACHE +// Extern declarations for FP8 kernels with BF16 query (implemented in xqa_loader_bf16_fp8_impl.cuh) +Status LaunchXQAFp8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); +#endif + // ============================================================================ // Specialization for BFloat16 // ============================================================================ @@ -171,6 +193,16 @@ Status LaunchXQAKernelImpl<__nv_bfloat16>( workspace_size); } +#ifdef USE_FP8_KV_CACHE + // Dispatch to FP8 path if requested + if (kv_quant_type == XqaQuantType::kFp8) { + return LaunchXQAFp8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, + batch_size, num_heads, kv_num_heads, head_size, max_seq_len, + scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, + workspace_size); + } +#endif + int group_size = num_heads / kv_num_heads; switch (group_size) { case 1: diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_128.cu new file mode 100644 index 0000000000000..f9697fdd2f614 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_128.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 128 +#define HEAD_DIM_NAMESPACE H128 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_fp16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_256.cu new file mode 100644 index 0000000000000..3f5d9ac3f5507 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_256.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 256 +#define HEAD_DIM_NAMESPACE H256 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_fp16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_64.cu new file mode 100644 index 0000000000000..ce894ebc384a6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_64.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
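// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): how a caller chooses between the
// none/INT8/FP8 XQA loader variants from the cache element type and quantization
// attribute, mirroring the mapping used in GroupQueryAttention::ComputeInternal and
// ExtremeDecoding. The enum and helper names below are hypothetical stand-ins for
// KVQuantizationType/XqaQuantType.
#include <type_traits>
#include <cuda_fp8.h>

enum class SketchKvQuant { kNone, kPerTensor };
enum class SketchXqaQuant { kNone, kInt8, kFp8 };

template <typename CacheT>
constexpr SketchXqaQuant ToXqaQuantSketch(SketchKvQuant kv_quant) {
  if (kv_quant == SketchKvQuant::kNone) {
    return SketchXqaQuant::kNone;  // unquantized cache: plain FP16/BF16 loader
  }
  // A quantized cache is FP8 when the cache tensor element is Float8E4M3FN,
  // otherwise the INT8 loader is used.
  return std::is_same<CacheT, __nv_fp8_e4m3>::value ? SketchXqaQuant::kFp8
                                                    : SketchXqaQuant::kInt8;
}

static_assert(ToXqaQuantSketch<__nv_fp8_e4m3>(SketchKvQuant::kPerTensor) == SketchXqaQuant::kFp8,
              "per-tensor FP8 cache routes to the FP8 XQA kernels");
// ---------------------------------------------------------------------------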
+ +#define HEAD_ELEMS 64 +#define HEAD_DIM_NAMESPACE H64 + +#ifdef USE_FP8_KV_CACHE +#include "xqa_loader_fp16_fp8_impl.cuh" +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh new file mode 100644 index 0000000000000..5e18d21defb79 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_fp8_impl.cuh @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "xqa_loader.h" +#include + +// HEAD_ELEMS must be defined by the including file +#ifndef HEAD_ELEMS +#error "HEAD_ELEMS must be defined before including xqa_loader_fp16_fp8_impl.cuh" +#endif + +// HEAD_DIM_NAMESPACE must be defined by the including file +#ifndef HEAD_DIM_NAMESPACE +#error "HEAD_DIM_NAMESPACE must be defined before including xqa_loader_fp16_fp8_impl.cuh" +#endif + +// Define global constants for FP8 E4M3 KV Cache +#define CACHE_ELEM_ENUM 2 // FP8 E4M3 +#define USE_PAGED_KV_CACHE 0 +#define TOKENS_PER_PAGE 0 +#define INPUT_FP16 1 // Q is FP16 +#define ALLOW_MULTI_BLOCK_MODE 1 + +#pragma nv_diag_suppress 177 +#pragma nv_diag_suppress 20012 + +// Include common headers once +#include "cuda_hint.cuh" +#include "mha.h" +// Include all helpers globally to ensure visibility +#include "ldgsts.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" +#include "hostUtils.h" + +// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen +#undef HEAD_GRP_SIZE +#undef M_TILESIZE + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace HEAD_DIM_NAMESPACE { + +// ============================================================================ +// FP8 E4M3 KV Cache Instantiations for FP16 Query +// ============================================================================ +#define NAMESPACE_NAME grp4_fp8 +#define GRP_SIZE 4 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp8_fp8 +#define GRP_SIZE 8 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp16_fp8 +#define GRP_SIZE 16 +#define M_TILESIZE 16 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp32_fp8 +#define GRP_SIZE 32 +#define M_TILESIZE 32 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +Status LaunchXQAFp8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size) { + int group_size = num_heads / kv_num_heads; + switch (group_size) { + case 4: + return grp4_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 8: + return grp8_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, 
kv_cache_scale, workspace, workspace_size); + case 16: + return grp16_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 32: + return grp32_fp8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 only supports group_size 4, 8, 16, 32. Input has ", group_size); + } +} + +} // namespace HEAD_DIM_NAMESPACE +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh index 8ba0fe3b1ee0d..675beb3c92d0f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh @@ -116,6 +116,28 @@ Status LaunchXQAInt8Kernel( void* workspace, size_t workspace_size); +#ifdef USE_FP8_KV_CACHE +// Extern declarations for FP8 kernels (implemented in xqa_loader_fp16_fp8_impl.cuh via instantiation) +Status LaunchXQAFp8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); +#endif + // ============================================================================ // Dispatcher Implementation // ============================================================================ @@ -152,6 +174,18 @@ Status LaunchXQAKernelImpl( } } +#ifdef USE_FP8_KV_CACHE + // Dispatch to FP8 path if requested + if (kv_quant_type == XqaQuantType::kFp8) { + if constexpr (std::is_same::value) { + return LaunchXQAFp8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else { + // BF16 case is handled in xqa_loader_bf16.cu via specialization + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA FP8 path mismatch."); + } + } +#endif + int group_size = num_heads / kv_num_heads; switch (group_size) { case 1: diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index ab692e0549d6c..e73ad25d96f38 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -115,6 +115,10 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_int8_t, GroupQueryAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_uint8_t, GroupQueryAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_uint8_t, GroupQueryAttention); #endif +#ifdef USE_FP8_KV_CACHE +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_Float8E4M3FN, GroupQueryAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_Float8E4M3FN, GroupQueryAttention); +#endif class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, PagedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention); @@ -361,6 
+365,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { #ifdef USE_INT4_KV_CACHE BuildKernelCreateInfo, BuildKernelCreateInfo, +#endif +#ifdef USE_FP8_KV_CACHE + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index fee5d9556e75b..092c05f9e081a 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -342,8 +342,14 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte : past_dims[2].dim_value(); present_shape.add_dim()->set_dim_value(present_sequence_length); } else { - // Cannot compute exact present_sequence_length, copy from past_key (may be dynamic) - *present_shape.add_dim() = past_dims[2]; + // Cannot compute exact present_sequence_length. + if (ctx.getNumInputs() > 6 && past_dims[2].has_dim_value() && past_dims[2].dim_value() == 0) { + // If total_sequence_length is provided and past_key has 0 length, present_key will grow. + // Leave the dimension as dynamic to avoid "Error merging shape info" warning. + present_shape.add_dim(); + } else { + *present_shape.add_dim() = past_dims[2]; + } } *present_shape.add_dim() = past_dims[3]; // head_size diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 3caaca9663c2e..e71f341e1e818 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -42,6 +42,13 @@ def ort_type_to_numpy_type(ort_type: str): "tensor(bool)": bool, "tensor(uint8)": numpy.uint8, "tensor(int8)": numpy.int8, + "tensor(double)": numpy.float64, + "tensor(int16)": numpy.int16, + "tensor(uint16)": numpy.uint16, + "tensor(uint32)": numpy.uint32, + "tensor(uint64)": numpy.uint64, + "tensor(complex64)": numpy.complex64, + "tensor(complex128)": numpy.complex128, } if ort_type not in ort_type_to_numpy_type_map: raise ValueError(f"{ort_type} not found in map") @@ -59,6 +66,19 @@ def ort_type_to_torch_type(ort_type: str): "tensor(bool)": torch.bool, "tensor(uint8)": torch.uint8, "tensor(int8)": torch.int8, + "tensor(double)": torch.float64, + "tensor(int16)": torch.int16, + "tensor(uint16)": torch.uint16, + "tensor(uint32)": torch.uint32, + "tensor(uint64)": torch.uint64, + "tensor(complex64)": torch.complex64, + "tensor(complex128)": torch.complex128, + "tensor(float8e4m3fn)": torch.float8_e4m3fn, + "tensor(float8e4m3fnuz)": torch.float8_e4m3fnuz, + "tensor(float8e5m2)": torch.float8_e5m2, + "tensor(float8e5m2fnuz)": torch.float8_e5m2fnuz, + "tensor(int4)": torch.int4, + "tensor(uint4)": torch.uint4, } if ort_type not in ort_type_to_torch_type_map: raise ValueError(f"{ort_type} not found in map") @@ -87,6 +107,21 @@ def ort_type_to_onnx_type(ort_type: str): "tensor(bool)": TensorProto.BOOL, "tensor(uint8)": TensorProto.UINT8, "tensor(int8)": TensorProto.INT8, + "tensor(double)": TensorProto.DOUBLE, + "tensor(int16)": TensorProto.INT16, + "tensor(uint16)": TensorProto.UINT16, + "tensor(uint32)": TensorProto.UINT32, + "tensor(uint64)": TensorProto.UINT64, + "tensor(complex64)": TensorProto.COMPLEX64, + "tensor(complex128)": TensorProto.COMPLEX128, + "tensor(float8e4m3fn)": TensorProto.FLOAT8E4M3FN, + "tensor(float8e4m3fnuz)": TensorProto.FLOAT8E4M3FNUZ, + "tensor(float8e5m2)": TensorProto.FLOAT8E5M2, + "tensor(float8e5m2fnuz)": 
TensorProto.FLOAT8E5M2FNUZ, + "tensor(float4e2m1)": TensorProto.FLOAT4E2M1, + "tensor(int4)": TensorProto.INT4, + "tensor(uint4)": TensorProto.UINT4, + "tensor(string)": TensorProto.STRING, } if ort_type not in ort_type_to_onnx_type_map: raise ValueError(f"{ort_type} not found in map") @@ -104,7 +139,15 @@ def numpy_type_to_torch_type(numpy_type: numpy.dtype): bool: torch.bool, numpy.uint8: torch.uint8, numpy.int8: torch.int8, + numpy.float64: torch.float64, + numpy.int16: torch.int16, + numpy.uint16: torch.int32, + numpy.uint32: torch.int64, + numpy.uint64: torch.int64, + numpy.complex64: torch.complex64, + numpy.complex128: torch.complex128, } + if numpy_type not in numpy_type_to_torch_type_map: raise ValueError(f"{numpy_type} not found in map") @@ -119,7 +162,16 @@ def torch_type_to_numpy_type(torch_type: torch.dtype): torch.float16: numpy.float16, torch.bool: bool, torch.uint8: numpy.uint8, + torch.int8: numpy.int8, + torch.float64: numpy.float64, + torch.int16: numpy.int16, + torch.uint16: numpy.uint16, + torch.uint32: numpy.uint32, + torch.uint64: numpy.uint64, + torch.complex64: numpy.complex64, + torch.complex128: numpy.complex128, } + if torch_type not in torch_type_to_numpy_type_map: raise ValueError(f"{torch_type} not found in map") diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index c44d6b606d3a2..3a835d0852a9d 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -22,6 +22,7 @@ class TestConfig: test_int4: bool = False test_int8: bool = False + test_fp8: bool = False def get_plot_algos(sm: int, local_window_size: int | None, config: TestConfig | None): @@ -37,17 +38,21 @@ def get_plot_algos(sm: int, local_window_size: int | None, config: TestConfig | # Add quantized variants if requested if sm >= 80 and config: - quant_vals = ["ort_gqa_int4", "ort_gqa_int8"] - quant_names = ["ORT-GQA-INT4", "ORT-GQA-INT8"] - quant_styles = [("purple", "dotted"), ("orange", "dashdot")] + quant_vals = ["ort_gqa_int4", "ort_gqa_int8", "ort_gqa_fp8"] + quant_names = ["ORT-GQA-INT4", "ORT-GQA-INT8", "ORT-GQA-FP8"] + quant_styles = [("purple", "dotted"), ("orange", "dashdot"), ("brown", "dashed")] if config.test_int4: - line_vals.extend(quant_vals[:1]) - line_names.extend(quant_names[:1]) - styles.extend(quant_styles[:1]) + line_vals.append(quant_vals[0]) + line_names.append(quant_names[0]) + styles.append(quant_styles[0]) if config.test_int8: - line_vals.extend(quant_vals[1:]) - line_names.extend(quant_names[1:]) - styles.extend(quant_styles[1:]) + line_vals.append(quant_vals[1]) + line_names.append(quant_names[1]) + styles.append(quant_styles[1]) + if config.test_fp8: + line_vals.append(quant_vals[2]) + line_names.append(quant_names[2]) + styles.append(quant_styles[2]) return { "line_vals": line_vals, @@ -116,6 +121,9 @@ def benchmark( elif "_int8" in provider: k_quant_type = v_quant_type = "PER_TENSOR" kv_cache_type = "int8" + elif "_fp8" in provider: + k_quant_type = v_quant_type = "PER_TENSOR" + kv_cache_type = "fp8" config: GroupQueryAttentionConfig = GroupQueryAttentionConfig( batch_size=batch_size, @@ -205,6 +213,10 @@ def benchmark( k_quant_type = v_quant_type = "PER_TENSOR" kv_cache_type = "int8" share_kv_scale = True # XQA requires shared scale + elif "_fp8" in provider: + k_quant_type = v_quant_type = "PER_TENSOR" + kv_cache_type = "fp8" + share_kv_scale = True # XQA requires shared scale config: GroupQueryAttentionConfig = 
GroupQueryAttentionConfig( batch_size=batch_size, @@ -303,7 +315,7 @@ def run_performance_test( s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - config = TestConfig(test_int4=False, test_int8=True) + config = TestConfig(test_int4=False, test_int8=True, test_fp8=True) run_performance_test(sm, fast=True, config=config, dtype="float16", is_prompt=True) run_performance_test(sm, fast=True, config=config, dtype="float16", is_prompt=False) # run_performance_test(sm, fast=True, config=config, dtype="bfloat16", is_prompt=True) diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py index cd34f4f420ad5..7f0d50a7ac8ed 100644 --- a/onnxruntime/test/python/transformers/gqa_test_helper.py +++ b/onnxruntime/test/python/transformers/gqa_test_helper.py @@ -21,6 +21,7 @@ "int32": TensorProto.INT32, "int8": TensorProto.INT8, "int4": TensorProto.UINT8, + "fp8": TensorProto.FLOAT8E4M3FN, } TORCH_DTYPE_TO_ONNX_MAP = { @@ -29,6 +30,7 @@ torch.bfloat16: TensorProto.BFLOAT16, torch.int32: TensorProto.INT32, torch.int8: TensorProto.INT8, + torch.float8_e4m3fn: TensorProto.FLOAT8E4M3FN, } TORCH_DTYPE_MAP = { @@ -37,6 +39,7 @@ "bfloat16": torch.bfloat16, "int8": torch.int8, "int4": torch.uint8, + "fp8": torch.float8_e4m3fn, } NUMPY_DTYPE_MAP = { @@ -45,6 +48,7 @@ "bfloat16": numpy.uint16, "int8": numpy.int8, "int4": numpy.uint8, + "fp8": numpy.uint8, # FP8 E4M3 stored as uint8 } @@ -54,6 +58,8 @@ def get_q_range(q_type_str): return -128, 127 if q_type_str.endswith("int4"): return -8, 7 + if q_type_str == "fp8": + return -448.0, 448.0 # FP8 E4M3 range raise ValueError(f"Unsupported quantization type for range: {q_type_str}") @@ -108,8 +114,14 @@ def dequantize_tensor(quantized_tensor, scale, quant_type, q_type_str): if isinstance(scale, torch.Tensor): scale = scale.to(quantized_tensor.device) - unpacked_tensor = quantized_tensor q_type_str_s = str(q_type_str) + + # FP8 dequantization: cast to float32 and multiply by scale + if q_type_str_s == "fp8": + # FP8 tensors are already float8_e4m3fn, just cast and scale + return quantized_tensor.to(torch.float32) * scale + + unpacked_tensor = quantized_tensor if q_type_str_s.endswith("int4"): unpacked_tensor = unpack_int4(quantized_tensor) @@ -121,10 +133,20 @@ def quantize_tensor_with_scale(tensor_float, scale, quant_type, q_type_str): if quant_type == "NONE": return tensor_float + q_type_str_s = str(q_type_str) + + # FP8 quantization: scale and cast to float8_e4m3fn (no rounding needed) + if q_type_str_s == "fp8": + # FP8 E4M3 has max representable value of 448.0 + # Scale the tensor and clamp to FP8 range, then cast + scaled = tensor_float / scale + clamped = torch.clamp(scaled, -448.0, 448.0) + return clamped.to(torch.float8_e4m3fn) + + # INT8/INT4 quantization: scale, round, clamp to integer range qmin, qmax = get_q_range(q_type_str) quantized = torch.clamp(torch.round(tensor_float / scale), qmin, qmax) - q_type_str_s = str(q_type_str) if q_type_str_s.endswith("int4"): quantized = pack_int4(quantized.to(torch.int8)) else: @@ -318,10 +340,17 @@ def __init__( # Quantization parameters self.k_quant_type = k_quant_type self.v_quant_type = v_quant_type - self.kv_cache_type = kv_cache_type - # Determine bit width from cache type if applicable - self.kv_cache_bit_width = 4 if kv_cache_type == "int4" else (8 if kv_cache_type == "int8" else 0) self.share_kv_scale = share_kv_scale + # Determine bit width from cache type if applicable + if kv_cache_type == "int4": + self.kv_cache_bit_width = 
4 + elif kv_cache_type == "int8": + self.kv_cache_bit_width = 8 + elif kv_cache_type == "fp8": + self.kv_cache_bit_width = 8 # FP8 is 8 bits + else: + self.kv_cache_bit_width = 0 + self.kv_cache_type = kv_cache_type def shape_dict(self): shapes = super().shape_dict() @@ -450,6 +479,8 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): cache_type = TensorProto.UINT8 elif config.kv_cache_type == "int8": cache_type = TensorProto.INT8 + elif config.kv_cache_type == "fp8": + cache_type = TensorProto.FLOAT8E4M3FN # Compute actual cache shapes (packed for INT4) past_key_shape = list(shape_dict["past_key"]) diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 5cbba989a4dbd..6def1be804743 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -57,7 +57,10 @@ has_int4_kv_cache = ", int4-kv-cache=" in get_build_info() -enable_debug_print = False +has_fp8_kv_cache = ", fp8-kv-cache=" in get_build_info() + +# Enable debug print if tensor or node dumping is enabled in build. +enable_debug_print = ("dump-tensor" in get_build_info()) or ("dump-node" in get_build_info()) enable_deterministic_check = True # ################################################################################################# @@ -674,7 +677,8 @@ def gqa_past_func( k_scale = k_scale.to(torch.float32) k_scale = k_scale.contiguous() bind_tensor(io_binding, "k_scale", k_scale, device, k_scale_ort_type) - if v_scale is not None: + + if v_scale is not None and not config.share_kv_scale: v_scale_ort_type = TensorProto.FLOAT if v_scale.dtype != torch.float32: v_scale = v_scale.to(torch.float32) @@ -931,20 +935,30 @@ def parity_check_gqa_prompt( elif causal: window_size = (-1, 0) - # --- PyTorch Reference Path --- - if config.kv_cache_bit_width == 4 or config.kv_cache_type == "int8": + if config.kv_cache_bit_width == 4 or config.kv_cache_type == "int8" or config.kv_cache_type == "fp8": + # k/v are already quantized (int8/fp8) in inputs k_ref_dequant = dequantize_tensor(k, k_scale, config.k_quant_type, config.kv_cache_type) v_ref_dequant = dequantize_tensor(v, v_scale, config.v_quant_type, config.kv_cache_type) else: k_ref_dequant = dequantize_tensor( - quantize_tensor_with_scale(k, k_scale, config.k_quant_type, config.kv_cache_type), - k_scale, + quantize_tensor_with_scale( + k, + k_scale.to(torch.float32) if k_scale is not None else None, + config.k_quant_type, + config.kv_cache_type, + ), + k_scale.to(torch.float32) if k_scale is not None else None, config.k_quant_type, config.kv_cache_type, ) v_ref_dequant = dequantize_tensor( - quantize_tensor_with_scale(v, v_scale, config.v_quant_type, config.kv_cache_type), - v_scale, + quantize_tensor_with_scale( + v, + v_scale.to(torch.float32) if v_scale is not None else None, + config.v_quant_type, + config.kv_cache_type, + ), + v_scale.to(torch.float32) if v_scale is not None else None, config.v_quant_type, config.kv_cache_type, ) @@ -1097,6 +1111,9 @@ def parity_check_gqa_prompt( elif config.kv_cache_type == "int8": # For int8, present_k is int8 data present_k_torch = torch.from_numpy(present_k.astype(numpy.int8)).to(device) + elif config.kv_cache_type == "fp8": + # For fp8, present_k is float8_e4m3fn data, returned as uint8/int8 by ORT python + present_k_torch = torch.from_numpy(present_k).view(torch.float8_e4m3fn).to(device) else: present_k_torch = torch.from_numpy(present_k).to(device) @@ -1134,6 +1151,8 @@ def 
parity_check_gqa_prompt( present_v_torch = torch.from_numpy(present_v).to(device) elif config.kv_cache_type == "int8": present_v_torch = torch.from_numpy(present_v.astype(numpy.int8)).to(device) + elif config.kv_cache_type == "fp8": + present_v_torch = torch.from_numpy(present_v).view(torch.float8_e4m3fn).to(device) else: present_v_torch = torch.from_numpy(present_v).to(device) @@ -1345,8 +1364,8 @@ def parity_check_gqa_past( # Quantize k and v for ORT when using quantized KV cache k_ort = k v_ort = v - if config.kv_cache_type in ["int8", "int4"]: - # NOTE: Quantize returns tensor with kv_cache_type (int8) + if config.kv_cache_type in ["int8", "int4", "fp8"]: + # NOTE: Quantize returns tensor with kv_cache_type (int8, int4, or fp8) k_ort = quantize_tensor_with_scale(k, k_scale, config.k_quant_type, config.kv_cache_type) v_ort = quantize_tensor_with_scale(v, v_scale, config.v_quant_type, config.kv_cache_type) @@ -1386,26 +1405,37 @@ def parity_check_gqa_past( if numpy.count_nonzero(out_ref_np) > 0 and numpy.count_nonzero(out_np) == 0: raise RuntimeError("Output is all zeros") + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + # --- Comparison --- - if config.k_quant_type == "NONE" and config.v_quant_type == "NONE": + compare_kv = (config.k_quant_type == "NONE" and config.v_quant_type == "NONE") or (config.kv_cache_type == "fp8") + if compare_kv: # Compare KV cache # Transpose reference back to BNSH to match ORT output k_cache_ref_np = k_cache_ref.transpose(1, 2).to(torch.float32).detach().cpu().numpy() v_cache_ref_np = v_cache_ref.transpose(1, 2).to(torch.float32).detach().cpu().numpy() - present_k_np = present_k.to(torch.float32).detach().cpu().numpy() - present_v_np = present_v.to(torch.float32).detach().cpu().numpy() - if not config.share_buffer: - total_len = config.past_kv_sequence_length + config.q_sequence_length - k_cache_ref_np = k_cache_ref_np[:, :, :total_len, :] - v_cache_ref_np = v_cache_ref_np[:, :, :total_len, :] + if isinstance(present_k, torch.Tensor): + present_k_torch = present_k.to(device) + present_v_torch = present_v.to(device) + else: + present_k_torch = torch.from_numpy(present_k).to(device) + present_v_torch = torch.from_numpy(present_v).to(device) + + if config.kv_cache_type == "fp8": + # FP8 cache needs dequantization for comparison with float reference + present_k_dequant = dequantize_tensor(present_k_torch, k_scale, config.k_quant_type, config.kv_cache_type) + present_v_dequant = dequantize_tensor(present_v_torch, v_scale, config.v_quant_type, config.kv_cache_type) + present_k_np = present_k_dequant.to(torch.float32).detach().cpu().numpy() + present_v_np = present_v_dequant.to(torch.float32).detach().cpu().numpy() + else: + present_k_np = present_k_torch.to(torch.float32).detach().cpu().numpy() + present_v_np = present_v_torch.to(torch.float32).detach().cpu().numpy() numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) - print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") - numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) - # Compare quantized cache with proper masking per batch if config.k_quant_type != "NONE": if isinstance(present_k, torch.Tensor): @@ -1415,6 +1445,8 @@ def parity_check_gqa_past( present_k_torch = torch.from_numpy(present_k).to(device) elif config.kv_cache_type == "int8": present_k_torch = 
torch.from_numpy(present_k.astype(numpy.int8)).to(device) + elif config.kv_cache_type == "fp8": + present_k_torch = torch.from_numpy(present_k).view(torch.float8_e4m3fn).to(device) else: present_k_torch = torch.from_numpy(present_k).to(device) @@ -1455,6 +1487,8 @@ def parity_check_gqa_past( present_v_torch = torch.from_numpy(present_v).to(device) elif config.kv_cache_type == "int8": present_v_torch = torch.from_numpy(present_v.astype(numpy.int8)).to(device) + elif config.kv_cache_type == "fp8": + present_v_torch = torch.from_numpy(present_v).view(torch.float8_e4m3fn).to(device) else: present_v_torch = torch.from_numpy(present_v).to(device) @@ -1851,8 +1885,14 @@ def gqa_cuda_quantized_test_cases(is_past: bool): else gqa_cuda_prompt_test_cases(allow_local=True) ) + kv_types = ["int8"] + if has_int4_kv_cache: + kv_types.append("int4") + if has_fp8_kv_cache: + kv_types.append("fp8") + for name, config in base_cases: - for kv_type in ["int8", "int4"] if has_int4_kv_cache else ["int8"]: + for kv_type in kv_types: for quant_mode in ["PER_TENSOR", "PER_CHANNEL"]: share_scales_options = [False] if quant_mode == "PER_TENSOR" and kv_type == "int8": @@ -1871,6 +1911,8 @@ def gqa_cuda_quantized_test_cases(is_past: bool): q_config.kv_cache_bit_width = 4 elif kv_type == "int8": q_config.kv_cache_bit_width = 8 + elif kv_type == "fp8": + q_config.kv_cache_bit_width = 8 q_name = f"{name}_quant_{kv_type}_{quant_mode}" if share_scales: @@ -1902,8 +1944,26 @@ def has_flash_attention(bf16=False): return True -rtol = {"fp16": 5e-3, "bf16": 5e-2, "int8_fp16": 5e-2, "int4_fp16": 5e-2, "int8_bf16": 5e-2, "int4_bf16": 5e-2} -atol = {"fp16": 5e-3, "bf16": 1e-2, "int8_fp16": 1e-1, "int4_fp16": 1e-1, "int8_bf16": 2e-1, "int4_bf16": 2e-1} +rtol = { + "fp16": 5e-3, + "bf16": 5e-2, + "int8_fp16": 5e-2, + "int4_fp16": 5e-2, + "int8_bf16": 5e-2, + "int4_bf16": 5e-2, + "fp8_fp16": 5e-2, + "fp8_bf16": 5e-2, +} +atol = { + "fp16": 5e-3, + "bf16": 1e-2, + "int8_fp16": 1e-1, + "int4_fp16": 1e-1, + "int8_bf16": 2e-1, + "int4_bf16": 2e-1, + "fp8_fp16": 1e-1, + "fp8_bf16": 2e-1, +} def has_quantized_kv_cache(): @@ -2355,6 +2415,134 @@ def test_gqa_int8_large_seq_batch4(self): atol=5e-2, ) + @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") + def test_gqa_fp8_kv_cache(self): + """ + Test GQA with FP8 E4M3 quantized KV cache. + Requires SM89+ (Ada Lovelace or newer) and USE_FP8_KV_CACHE build flag. + """ + config = GQAConfig( + batch_size=2, + num_heads=32, + kv_num_heads=8, + head_size=128, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=127, + buffer_sequence_length=128, + rotary=True, + rotary_interleaved=False, + k_quant_type="PER_TENSOR", + v_quant_type="PER_TENSOR", + kv_cache_type="fp8", + share_buffer=True, + share_kv_scale=True, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + try: + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=5e-2, + atol=5e-2, + ) + except Exception as e: + # FP8 may not be built, skip if kernel not registered + if "Float8E4M3FN" in str(e) or "fp8" in str(e).lower(): + self.skipTest(f"FP8 KV cache not available: {e}") + raise + + @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") + def test_gqa_fp8_prompt(self): + """ + Test GQA Prompt phase with FP8 E4M3 quantized KV cache. 
+ """ + config = GQAConfig( + batch_size=2, + num_heads=32, + kv_num_heads=8, + head_size=128, + q_sequence_length=128, + kv_sequence_length=128, + past_kv_sequence_length=0, + buffer_sequence_length=128, + rotary=True, + rotary_interleaved=False, + k_quant_type="PER_TENSOR", + v_quant_type="PER_TENSOR", + kv_cache_type="fp8", + share_buffer=True, + share_kv_scale=True, + kv_cache_bit_width=8, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + try: + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=5e-2, + atol=5e-2, + ) + except Exception as e: + if "Float8E4M3FN" in str(e) or "fp8" in str(e).lower(): + self.skipTest(f"FP8 KV cache not available: {e}") + raise + + @unittest.skipIf(not has_cuda_device(89) or not has_fp8_kv_cache, "FP8 KV cache is not available, skipping tests.") + def test_gqa_fp8_fallback_unsupported_head_size(self): + """ + Test GQA with FP8 KV cache on a head size not supported by XQA. + This forces fallback to the generic generic kernel (if available) or ensures graceful failure/correctness. + """ + config = GQAConfig( + batch_size=2, + num_heads=32, + kv_num_heads=8, + head_size=48, # Valid head size (multiple of 16) but not supported by XQA (supports 64, 128, 256) + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=64, + buffer_sequence_length=128, + rotary=True, + rotary_interleaved=False, + k_quant_type="PER_TENSOR", + v_quant_type="PER_TENSOR", + kv_cache_type="fp8", + share_buffer=True, + share_kv_scale=True, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=5e-2, + atol=5e-2, + ) + if __name__ == "__main__": unittest.main() From abcbdadc08e315fb5147e997eedbcb07b7757505 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 12 Feb 2026 00:49:58 +0000 Subject: [PATCH 03/11] update comments --- .../contrib_ops/cuda/bert/group_query_attention_impl.cu | 2 +- .../contrib_ops/cuda/bert/group_query_attention_qdq.cuh | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index dc0fc1b32b2bc..70ed5e9bd1ace 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -858,7 +858,7 @@ Status DequantizeFlashAttentionFallback( return Status::OK(); } -// Use Flash Attention for float key and value, then quantize key/value to int8 to save to k/v cache. +// Use Flash Attention for float key and value, then quantize key/value (int8/fp8/int4) to save to k/v cache. template Status FlashAttentionAndQuantizeKV( const cudaDeviceProp& device_prop, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh index a16dc46046951..a9ffa618e403c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
#pragma once -// Enable quantized KV cache support for INT8/INT4 +// Enable quantized KV cache support for INT8/INT4/FP8 #define KV_QUANT_SUPPORTED 1 #include @@ -49,7 +49,7 @@ struct TypeConverter<__nv_bfloat16> { // ============================================================================ // // This file implements symmetric quantization for KV cache in GroupQueryAttention. -// Supports INT4 and INT8 with PER_TENSOR and PER_CHANNEL quantization modes. +// Supports INT4, INT8, and FP8 (E4M3) with PER_TENSOR and PER_CHANNEL quantization modes. // // QUANTIZATION SCHEME: // ------------------- @@ -96,7 +96,7 @@ struct TypeConverter<__nv_bfloat16> { // - Conversion: Native CUDA cast via __nv_cvt_float_to_fp8/fp8_to_float // ============================================================================ -// Dequantization Kernel: Converts Quantized (Int8/Int4) KV cache back to Floating Point (T). +// Dequantization Kernel: Converts Quantized (Int8/Int4/FP8) KV cache back to Floating Point (T). // Iterates over every individual element with one thread per element. template __global__ void DequantizeKernel(T* dequantized_data, @@ -195,7 +195,7 @@ Status LaunchDequantizeKV(cudaStream_t stream, T* dequantized_data, return CUDA_CALL(cudaGetLastError()); } -// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4) values. +// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4/FP8) values. // Note: This kernel is used to quantize a full input tensor, e.g. during graph initialization // or fallback paths. The main prompt path uses the fused UnpackRoPEAppend kernel. template From 3a260ef88942fa0296e4417e0171562c741b0331 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 12 Feb 2026 05:54:05 +0000 Subject: [PATCH 04/11] update doc --- docs/OperatorKernels.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 0230f2866fcb4..41811201cbf0e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1003,7 +1003,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)<br> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)<br> **T_KV_SCALE** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)<br> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)<br> **T_KV_SCALE** = tensor(float)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br>
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| From 9ad73e432ba2b0e9efbeb69ae20d14d1d0d5d1f1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 12 Feb 2026 06:53:26 +0000 Subject: [PATCH 05/11] udpate io binding helper type mapping --- .../tools/transformers/io_binding_helper.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index e71f341e1e818..ced5f4ebec3aa 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -35,8 +35,8 @@ def get_output_type(ort_session, name: str) -> str: @staticmethod def ort_type_to_numpy_type(ort_type: str): ort_type_to_numpy_type_map = { - "tensor(int64)": numpy.longlong, - "tensor(int32)": numpy.intc, + "tensor(int64)": numpy.int64, + "tensor(int32)": numpy.int32, "tensor(float)": numpy.float32, "tensor(float16)": numpy.float16, "tensor(bool)": bool, @@ -131,8 +131,7 @@ def ort_type_to_onnx_type(ort_type: str): @staticmethod def numpy_type_to_torch_type(numpy_type: numpy.dtype): numpy_type_to_torch_type_map = { - numpy.longlong: torch.int64, - numpy.intc: torch.int32, + numpy.int64: torch.int64, numpy.int32: torch.int32, numpy.float32: torch.float32, numpy.float16: torch.float16, @@ -141,9 +140,9 @@ def numpy_type_to_torch_type(numpy_type: numpy.dtype): numpy.int8: torch.int8, numpy.float64: torch.float64, numpy.int16: torch.int16, - numpy.uint16: torch.int32, - numpy.uint32: torch.int64, - numpy.uint64: torch.int64, + numpy.uint16: torch.uint16, + numpy.uint32: torch.uint32, + numpy.uint64: torch.uint64, numpy.complex64: torch.complex64, numpy.complex128: torch.complex128, } @@ -156,8 +155,8 @@ def numpy_type_to_torch_type(numpy_type: numpy.dtype): @staticmethod def torch_type_to_numpy_type(torch_type: torch.dtype): torch_type_to_numpy_type_map = { - torch.int64: numpy.longlong, - torch.int32: numpy.intc, + torch.int64: numpy.int64, + torch.int32: numpy.int32, torch.float32: numpy.float32, torch.float16: numpy.float16, torch.bool: bool, From e83141ab4593bdd53c9162cf94d9392011972bac Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 12 Feb 2026 07:01:10 +0000 Subject: [PATCH 06/11] copilot feedback --- .../contrib_ops/cuda/bert/group_query_attention_qkv.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 851cc35018e2b..6bf7c5b01e4d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -212,7 +212,7 @@ __global__ void UnpackRoPEAppend( if constexpr (std::is_same::value) { // FP8 E4M3 Quantization: scale and convert to FP8 format constexpr float kFp8E4M3Max = 448.0f; - for (int i = 0; i < 8; ++i) { + for (int i = 0; i < elements_per_thread; ++i) { float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0]; float scaled_val = min(kFp8E4M3Max, max(-kFp8E4M3Max, static_cast(vals[i]) * (sc == 0.0f ? 0.0f : 1.0f / sc))); __nv_fp8_e4m3 fp8_val = __nv_fp8_e4m3(scaled_val); @@ -222,7 +222,7 @@ __global__ void UnpackRoPEAppend( #endif { // INT8 Quantization: round and clamp to [-128, 127] - for (int i = 0; i < 8; ++i) { + for (int i = 0; i < elements_per_thread; ++i) { float sc = per_channel ? 
scale_buffer[n * head_size + h + i] : scale_buffer[0]; int8_t q = static_cast(max(-128.0f, min(127.0f, rintf(static_cast(vals[i]) * (sc == 0.0f ? 0.0f : 1.0f / sc))))); packed |= (static_cast(static_cast(q)) << (i * 8)); From d28d92da5713abdbaeafe1b5b313e70fe19ef172 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 12 Feb 2026 07:16:19 +0000 Subject: [PATCH 07/11] update test --- onnxruntime/test/python/transformers/test_mha_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/python/transformers/test_mha_flash_attn.py b/onnxruntime/test/python/transformers/test_mha_flash_attn.py index a015ce6979f91..c66c1e96bb437 100644 --- a/onnxruntime/test/python/transformers/test_mha_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_mha_flash_attn.py @@ -371,7 +371,7 @@ def parity_check_mha( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() # Pytorch to compare - out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=False) + out_ref, _ = attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, attention_bias=None, causal=False) out_ref = out_ref.detach().cpu().numpy() numpy.testing.assert_allclose( From 2a52fc15e497adf495fb621799230990a7766d01 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 12 Feb 2026 07:21:34 +0000 Subject: [PATCH 08/11] lintrunner --- onnxruntime/test/python/transformers/test_mha_flash_attn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/python/transformers/test_mha_flash_attn.py b/onnxruntime/test/python/transformers/test_mha_flash_attn.py index c66c1e96bb437..150f8418d75ab 100644 --- a/onnxruntime/test/python/transformers/test_mha_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_mha_flash_attn.py @@ -371,7 +371,9 @@ def parity_check_mha( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() # Pytorch to compare - out_ref, _ = attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, attention_bias=None, causal=False) + out_ref, _ = attention_ref( + q, k, v, query_padding_mask=None, key_padding_mask=None, attention_bias=None, causal=False + ) out_ref = out_ref.detach().cpu().numpy() numpy.testing.assert_allclose( From 068d05292d3369eaf81cd0d20272f4b926cbf997 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 13 Feb 2026 01:30:43 +0000 Subject: [PATCH 09/11] consolidate cuda type --- .../cuda/bert/group_query_attention_impl.cu | 218 +++++--------- .../cuda/bert/group_query_attention_qdq.cuh | 57 +--- .../cuda/bert/group_query_attention_qkv.cuh | 2 +- .../core/providers/cuda/llm/attention.cc | 282 +++++++++--------- 4 files changed, 234 insertions(+), 325 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 70ed5e9bd1ace..dd0942d2b7f5c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -77,21 +77,19 @@ namespace cuda { // 3. Ensuring synchronization between past and present KV caches when necessary. // 4. Launching the UnpackRoPEQuantizeAppend kernel to unpack, apply RoPE, and update caches. // 5. Returning strict Q, K, V pointers ready for the core attention operation. 
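For readers following the quantized-cache branches below, the per-tensor FP8 (E4M3) round trip used throughout this series (quantize on append, dequantize in the fallback path and in the test reference) boils down to the following sketch. It mirrors quantize_tensor_with_scale / dequantize_tensor in gqa_test_helper.py rather than the CUDA kernels, and the amax-based scale choice is only an illustrative assumption:

import torch

def fp8_kv_roundtrip(x: torch.Tensor):
    # Symmetric per-tensor scale; amax / 448 is one plausible choice (assumption),
    # 448 being the largest finite FP8 E4M3 value. Requires PyTorch >= 2.1.
    scale = (x.abs().amax().float() / 448.0).clamp(min=1e-12)
    q = torch.clamp(x.float() / scale, -448.0, 448.0).to(torch.float8_e4m3fn)  # quantize
    dq = q.to(torch.float32) * scale                                           # dequantize
    return q, dq, scale

q, dq, scale = fp8_kv_roundtrip(torch.randn(2, 8, 128, 64, dtype=torch.float16))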
-template +template Status PrepareQKV( cudaStream_t stream, const int max_threads_per_block, const GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, - const T*& q) { + GroupQueryAttentionData& data, + const CudaT*& q) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; CudaT* q_out = reinterpret_cast(data.qkv_buffer); if (!parameters.is_packed_qkv && !parameters.do_rotary) { @@ -137,9 +135,9 @@ Status PrepareQKV( stream, max_threads_per_block))); if (q_out != nullptr) { - q = reinterpret_cast(q_out); + q = reinterpret_cast(q_out); } else { - q = reinterpret_cast(data.query); + q = reinterpret_cast(data.query); } return Status::OK(); @@ -148,16 +146,16 @@ Status PrepareQKV( ////////// Auxiliary Kernels for KV prep // Concat new to past in present. Supports past BSNH or past BNSH -template +template Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, const void* new_key, const void* new_value, cudaStream_t stream, const int max_threads_per_block, const bool past_only = false, - const T* cos_cache = nullptr, - const T* sin_cache = nullptr, + const CudaT* cos_cache = nullptr, + const CudaT* sin_cache = nullptr, const int rotary_dim = 0, const int64_t* position_ids = nullptr, const bool interleaved = false) { @@ -180,12 +178,12 @@ Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, is_bsnh, data.past_seq_lens, data.total_seq_lens, - reinterpret_cast(data.past_key), - reinterpret_cast(data.past_value), - reinterpret_cast(new_key), - reinterpret_cast(new_value), - reinterpret_cast(data.present_key), - reinterpret_cast(data.present_value), + reinterpret_cast(data.past_key), + reinterpret_cast(data.past_value), + reinterpret_cast(new_key), + reinterpret_cast(new_value), + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), stream, max_threads_per_block, past_only, @@ -197,9 +195,9 @@ Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, } // Concat new to kv buffer in place -template +template Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, const void* new_key, const void* new_value, bool is_new_kv_bnsh_format, @@ -218,10 +216,10 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, data.past_seq_lens, data.total_seq_lens, parameters.sequence_length, - reinterpret_cast(new_key), - reinterpret_cast(new_value), - reinterpret_cast(data.present_key), - reinterpret_cast(data.present_value), + reinterpret_cast(new_key), + reinterpret_cast(new_value), + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, @@ -254,9 +252,9 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, // - q_num_heads is divisible by kv_num_heads // - H * q_num_heads <= max_threads_per_block (use UngroupLarge otherwise) // ============================================================================ -template -__global__ void Ungroup(const T* kv_in, - T* kv_out, +template +__global__ void Ungroup(const CudaT* kv_in, + 
CudaT* kv_out, const int in_seqlen, const int kv_num_heads, const bool is_bsnh) { @@ -297,9 +295,9 @@ __global__ void Ungroup(const T* kv_in, // blockIdx.y = s (sequence position) // blockIdx.z = b (batch index) // ============================================================================ -template -__global__ void UngroupLarge(const T* kv_in, - T* kv_out, +template +__global__ void UngroupLarge(const CudaT* kv_in, + CudaT* kv_out, const int H, const int in_seqlen, const int q_num_heads, @@ -330,7 +328,7 @@ __global__ void UngroupLarge(const T* kv_in, } // Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. -template +template Status LaunchUngroup(const GroupQueryAttentionParameters& parameters, float2* k_buff, float2* v_buff, const float2* k_og, const float2* v_og, @@ -406,8 +404,8 @@ Status LaunchUngroup(const GroupQueryAttentionParameters& parameters, // One thread per element in packed_qkv. Thread determines which of Q/K/V // the element belongs to based on the offset within the hidden dimension. // ============================================================================ -template -__global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, +template +__global__ void UnpackQKV(const CudaT* packed_qkv, CudaT* unpacked_q, CudaT* unpacked_k, CudaT* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size) { const int tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -421,7 +419,7 @@ __global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* int offset = tid % d; if (output_bnsh) { // output BNSH int head_count = kv_num_heads; - T* unpacked = nullptr; + CudaT* unpacked = nullptr; if (offset < q_hidden) { unpacked = unpacked_q; head_count = num_heads; @@ -468,13 +466,13 @@ __global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* } // Unpack packed qkv -template -Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, +template +Status LaunchUnpackQKV(const CudaT* packed_qkv, CudaT* unpacked_q, CudaT* unpacked_k, CudaT* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block) { const int threads = max_threads_per_block; const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads; - UnpackQKV<<>>( + UnpackQKV<<>>( packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads, head_size, sequence_length, batch_size); return CUDA_CALL(cudaGetLastError()); } @@ -554,12 +552,12 @@ Status LaunchGetSequenceLengths( ////////// Kernels (supports right padding but not left padding) // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. // Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. 
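Pulling together the constraints that appear elsewhere in this series (the FP8 XQA loaders for head sizes 64/128/256, the group-size dispatcher, the shared-scale note in benchmark_gqa.py, and the SM89 gate in the tests), a rough sketch of when the quantized fast-decode path can apply is below; the authoritative check is the C++ dispatch, not this helper:

def may_use_fp8_xqa_decode(is_first_prompt: bool, num_heads: int, kv_num_heads: int,
                           head_size: int, per_tensor: bool, shared_kv_scale: bool,
                           sm: int) -> bool:
    # Sketch only; mirrors constraints visible in this patch, not the actual dispatcher.
    group_size = num_heads // kv_num_heads
    return (not is_first_prompt                 # first prompt quantizes after flash attention instead
            and head_size in (64, 128, 256)     # xqa_loader_*_fp8_{64,128,256}.cu
            and group_size in (4, 8, 16, 32)    # "XQA FP8 only supports group_size 4, 8, 16, 32"
            and per_tensor and shared_kv_scale  # PER_TENSOR quantization with a shared K/V scale
            and sm >= 89)                       # tests gate FP8 on has_cuda_device(89)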
-template +template Status ExtremeDecoding( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { ORT_GQA_TRACE("ExtremeDecoding"); @@ -573,8 +571,6 @@ Status ExtremeDecoding( // bool is_causal = parameters.is_unidirectional; // bool is_bf16 = std::is_same::value; - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; bool past_bsnh = (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); // Ultimate Fused Preprocessing: Unpack, RoPE Q, RoPE K, Quantize K/V, Append K/V @@ -652,12 +648,12 @@ Status ExtremeDecoding( // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. // Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. -template +template Status FlashDecoding( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { assert(!parameters.is_first_prompt && parameters.past_present_share_buffer); @@ -671,29 +667,29 @@ Status FlashDecoding( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value || std::is_same::value; + bool is_bf16 = std::is_same::value; - void* query = reinterpret_cast(const_cast(data.query)); + void* query = reinterpret_cast(const_cast(data.query)); void* key; void* value; if (!parameters.is_packed_qkv) { - key = reinterpret_cast(const_cast(data.key)); - value = reinterpret_cast(const_cast(data.value)); + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); } else { const size_t key_offset = static_cast(num_heads * head_size); const size_t value_offset = static_cast(kv_num_heads * head_size); - key = reinterpret_cast(query) + key_offset; - value = reinterpret_cast(key) + value_offset; + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; } void* seqlens_k = reinterpret_cast(data.past_seq_lens); void* present_key = data.present_key; void* present_value = data.present_value; - void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); - void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); - void* head_sink = reinterpret_cast(const_cast(data.head_sink)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + void* head_sink = reinterpret_cast(const_cast(data.head_sink)); bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; @@ -716,12 +712,12 @@ Status FlashDecoding( // Use extra kernel(s) for unpacking, rotary and kv append. // Flash attention is used for attention only. 
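Since several of these paths fold RoPE into the prep kernels, here is a minimal host-side reference for the non-interleaved rotary convention, assuming cos/sin caches of shape (max_seq_len, rotary_dim // 2). The op also supports an interleaved layout via rotary_interleaved, which this sketch does not cover:

import torch

def rope_non_interleaved(x, cos_cache, sin_cache, position_ids):
    # x: (batch, seq, heads, rotary_dim); cos/sin cache: (max_seq_len, rotary_dim // 2);
    # position_ids: (batch, seq) int64.
    cos = cos_cache[position_ids].unsqueeze(2)  # -> (batch, seq, 1, rotary_dim // 2)
    sin = sin_cache[position_ids].unsqueeze(2)
    x1, x2 = x.chunk(2, dim=-1)                 # split the rotary dim in half
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)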
-template +template Status FlashAttention( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; @@ -733,21 +729,21 @@ Status FlashAttention( AttentionQkvFormat past_kv_format = parameters.past_kv_format; bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value || std::is_same::value; + bool is_bf16 = std::is_same::value; DUMP_TENSOR_INIT(); - const T* q_prep = nullptr; - ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); + const CudaT* q_prep = nullptr; + ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); - void* query = const_cast(q_prep); + void* query = const_cast(q_prep); void* present_key = data.present_key; void* present_value = data.present_value; // Disable internal RoPE in Flash Attention (pass nullptr) void* cos_cache = nullptr; void* sin_cache = nullptr; - void* head_sink = reinterpret_cast(const_cast(data.head_sink)); + void* head_sink = reinterpret_cast(const_cast(data.head_sink)); // We have already appended (and quantized if needed) the new tokens into present_key/value. // Pass nullptr for new_k/new_v to disable flash attention kernel's internal Append_KV logic. @@ -782,12 +778,12 @@ Status FlashAttention( // Fallback path for decoding quantized kv cache, when XQA is not usable (due to softcap, window, etc.) // We dequantize the cache and run standard Flash Attention. -template +template Status DequantizeFlashAttentionFallback( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { assert(!parameters.is_first_prompt); // Only support first prompt for this function. assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE); @@ -796,8 +792,7 @@ Status DequantizeFlashAttentionFallback( // We need to dequantize the entire KV cache (present_key/value) into a float/half buffer (data.qkv_buffer). // Layout in qkv_buffer: [Q (rotated)] [K_dequantized] [V_dequantized] - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; + CudaT* q_rot = reinterpret_cast(data.qkv_buffer); size_t q_elements = static_cast(parameters.batch_size) * parameters.sequence_length * parameters.num_heads * parameters.head_size; size_t k_elements = static_cast(parameters.batch_size) * parameters.seqlen_present_kv_cache * parameters.kv_num_heads * parameters.head_size; @@ -838,7 +833,7 @@ Status DequantizeFlashAttentionFallback( // Step 3: Run Flash Attention on dequantized k/v bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value || std::is_same::value; + bool is_bf16 = std::is_same::value; // Use the total_seq_lens here since k_dequant/v_dequant has both past and new tokens. 
void* seqlens_k_ptr = const_cast(reinterpret_cast(data.total_seq_lens)); @@ -847,7 +842,7 @@ Status DequantizeFlashAttentionFallback( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, q_rot, k_dequant, v_dequant, nullptr /*new K*/, nullptr /*new V*/, data.output, reinterpret_cast(data.softmax_lse), seqlens_k_ptr, nullptr /*cos_cache*/, nullptr /*sin_cache*/, - /*cache_batch_idx*/ nullptr, /*leftpad_k*/ nullptr, reinterpret_cast(const_cast(data.head_sink)), /*block_table*/ nullptr, + /*cache_batch_idx*/ nullptr, /*leftpad_k*/ nullptr, reinterpret_cast(const_cast(data.head_sink)), /*block_table*/ nullptr, parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, parameters.seqlen_present_kv_cache, parameters.sequence_length, 0 /*rotary_dim = 0 as it is already rotated*/, scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, is_bsnh, parameters.num_splits, @@ -859,12 +854,12 @@ Status DequantizeFlashAttentionFallback( } // Use Flash Attention for float key and value, then quantize key/value (int8/fp8/int4) to save to k/v cache. -template +template Status FlashAttentionAndQuantizeKV( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { assert(parameters.is_first_prompt); // Only support first prompt for this function. assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE); @@ -878,12 +873,11 @@ Status FlashAttentionAndQuantizeKV( ORT_GQA_TRACE("FlashAttentionAndQuantizeKV"); - bool past_bsnh = parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_ENFORCE(parameters.past_kv_format != AttentionQkvFormat::Q_K_V_BSNH, "GQA only supports BNSH format for KV cache."); size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - using CudaT = typename onnxruntime::cuda::OrtToCudaType::type; CudaT* q_final = reinterpret_cast(data.qkv_buffer); // For FlashAttentionAndQuantizeKV, we need float K and V for attention. @@ -907,7 +901,7 @@ Status FlashAttentionAndQuantizeKV( // 2. Run Float Flash Attention bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value || std::is_same::value; + bool is_bf16 = std::is_same::value; int local_window_size = parameters.local_window_size > 0 ? 
parameters.local_window_size - 1 : -1; @@ -923,57 +917,17 @@ Status FlashAttentionAndQuantizeKV( local_window_size)); if (parameters.k_quant_type != KVQuantizationType::NONE) { - if (parameters.kv_cache_bit_width == 8) { -#ifdef USE_FP8_KV_CACHE - if constexpr (std::is_same::value) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast<__nv_fp8_e4m3*>(data.present_key), reinterpret_cast(k_final), data.k_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 8, parameters.k_quant_type, true, past_bsnh))); - } else { -#endif - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 8, parameters.k_quant_type, true, past_bsnh))); -#ifdef USE_FP8_KV_CACHE - } -#endif -#ifdef USE_INT4_KV_CACHE - } else if (parameters.kv_cache_bit_width == 4) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 4, parameters.k_quant_type, true, past_bsnh))); -#endif - } + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, parameters.kv_cache_bit_width, parameters.k_quant_type, true))); } if (parameters.v_quant_type != KVQuantizationType::NONE) { - if (parameters.kv_cache_bit_width == 8) { -#ifdef USE_FP8_KV_CACHE - if constexpr (std::is_same::value) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast<__nv_fp8_e4m3*>(data.present_value), reinterpret_cast(v_final), data.v_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 8, parameters.v_quant_type, true, past_bsnh))); - } else { -#endif - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 8, parameters.v_quant_type, true, past_bsnh))); -#ifdef USE_FP8_KV_CACHE - } -#endif -#ifdef USE_INT4_KV_CACHE - } else if (parameters.kv_cache_bit_width == 4) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, - nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, - head_size, 4, parameters.v_quant_type, true, past_bsnh))); -#endif - } + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, parameters.kv_cache_bit_width, parameters.v_quant_type, true))); } return Status::OK(); @@ -981,12 +935,12 @@ Status FlashAttentionAndQuantizeKV( #endif #if USE_MEMORY_EFFICIENT_ATTENTION -template +template Status EfficientAttention( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, 
float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; @@ -999,8 +953,8 @@ Status EfficientAttention( ORT_GQA_TRACE("EfficientAttention"); - const T* q_prep = nullptr; - ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); + const CudaT* q_prep = nullptr; + ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); const void* query = reinterpret_cast(q_prep); const void* key; @@ -1017,16 +971,16 @@ Status EfficientAttention( const float2* k_og = reinterpret_cast(data.present_key); const float2* v_og = reinterpret_cast(data.present_value); - ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, - present_sequence_length, is_kv_bsnh, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, + present_sequence_length, is_kv_bsnh, stream, max_threads_per_block)); key = reinterpret_cast(data.k); value = reinterpret_cast(data.v); } MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; - p.is_bf16 = std::is_same::value || std::is_same::value; - p.is_half = !p.is_bf16 && (sizeof(T) == 2); + p.is_bf16 = std::is_same::value; + p.is_half = !p.is_bf16 && (sizeof(CudaT) == 2); p.batch_size = batch_size; p.num_heads = num_heads; p.sequence_length = sequence_length; @@ -1046,7 +1000,7 @@ Status EfficientAttention( p.attn_bias = nullptr; p.is_kv_bsnh = is_kv_bsnh; p.output = data.output; - p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(CudaT) == sizeof(float)) ? data.fmha_buffer : nullptr; p.stream = stream; @@ -1061,13 +1015,13 @@ Status EfficientAttention( ////////// API Functions -template +template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& /*cublas*/, Stream* ort_stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data) { + GroupQueryAttentionData& data) { auto stream = static_cast(ort_stream->GetHandle()); const float scale = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; if (data.use_xqa) { @@ -1103,7 +1057,6 @@ Status QkvToContext( template struct GroupQueryAttentionData; template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>; -template struct GroupQueryAttentionData; template struct GroupQueryAttentionData; template Status QkvToContext( @@ -1120,13 +1073,6 @@ template Status QkvToContext<__nv_bfloat16, __nv_bfloat16>( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>& data); -template Status QkvToContext( - const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, - contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data); - template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh index a9ffa618e403c..b69b0238686a6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh @@ -207,8 +207,7 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, int input_sequence_length, int cache_sequence_length, int num_heads, int head_size, int bit_width, KVQuantizationType quant_type, - bool is_input_bsnh, - bool is_output_bsnh) { + bool is_input_bsnh) { // elements_per_head_packed is the number of BYTES occupied by head_size elements. int elements_per_head_packed = (bit_width == 4) ? (head_size + 1) / 2 : head_size; @@ -234,30 +233,12 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, if (s >= total_valid_len_b) { if (bit_width == 8) { int64_t out_idx = i; - if (is_output_bsnh) { - int64_t b_idx = b; - int64_t n_idx = n; - int64_t s_idx = s; - int64_t h_idx = i % elements_per_head_packed; - out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + - s_idx * num_heads * elements_per_head_packed + - n_idx * elements_per_head_packed + - h_idx; - } + reinterpret_cast(quantized_data)[out_idx] = 0; #ifdef USE_FP8_KV_CACHE } else if constexpr (std::is_same::value) { // FP8 int64_t out_idx = i; - if (is_output_bsnh) { - int64_t b_idx = b; - int64_t n_idx = n; - int64_t s_idx = s; - int64_t h_idx = i % elements_per_head_packed; - out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + - s_idx * num_heads * elements_per_head_packed + - n_idx * elements_per_head_packed + - h_idx; - } + reinterpret_cast<__nv_fp8_e4m3*>(quantized_data)[out_idx] = __nv_fp8_e4m3(0.0f); #endif #ifdef USE_INT4_KV_CACHE @@ -271,16 +252,7 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, // Since `h_idx` comes from `i % elements_per_head_packed`, `out_idx` is guaranteed // to be within the buffer bounds. Writing kInt4ZeroPacked is safe. 
int64_t out_idx = i; - if (is_output_bsnh) { - int64_t b_idx = b; - int64_t n_idx = n; - int64_t s_idx = s; - int64_t h_idx = i % elements_per_head_packed; - out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + - s_idx * num_heads * elements_per_head_packed + - n_idx * elements_per_head_packed + - h_idx; - } + // INT4 uses +8 bias, so zero values pack to 0x88 reinterpret_cast(quantized_data)[out_idx] = kInt4ZeroPacked; #endif @@ -289,16 +261,6 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, } int64_t output_idx = i; - if (is_output_bsnh) { - int64_t b_idx = b; - int64_t n_idx = n; - int64_t s_idx = s; - int64_t h_idx = i % elements_per_head_packed; - output_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + - s_idx * num_heads * elements_per_head_packed + - n_idx * elements_per_head_packed + - h_idx; - } #ifdef USE_FP8_KV_CACHE if constexpr (std::is_same::value) { @@ -312,9 +274,9 @@ __global__ void QuantizeKernel(T_QUANT* quantized_data, } float inv_scale = (scale_val == 0.0f) ? 0.0f : 1.0f / scale_val; - int64_t flattened_input_idx = is_input_bsnh ? ((int64_t)b * input_sequence_length * num_heads * head_size + - (int64_t)s * num_heads * head_size + - (int64_t)n * head_size + + int64_t flattened_input_idx = is_input_bsnh ? (static_cast(b) * input_sequence_length * num_heads * head_size + + static_cast(s) * num_heads * head_size + + static_cast(n) * head_size + h) : ((int64_t)b * num_heads * input_sequence_length * head_size + (int64_t)n * input_sequence_length * head_size + @@ -419,8 +381,7 @@ Status LaunchQuantizeKV(cudaStream_t stream, T_QUANT* quantized_data, int batch_size, int num_heads, int input_sequence_length, int cache_sequence_length, int head_size, int bit_width, KVQuantizationType quant_type, - bool is_input_bsnh, - bool is_output_bsnh) { + bool is_input_bsnh) { assert(total_seq_lens != nullptr); if (cache_sequence_length == 0) return Status::OK(); @@ -431,7 +392,7 @@ Status LaunchQuantizeKV(cudaStream_t stream, T_QUANT* quantized_data, QuantizeKernel<<>>( quantized_data, dequantized_data, scale, past_seq_lens, total_seq_lens, total_packed_elements, - input_sequence_length, cache_sequence_length, num_heads, head_size, bit_width, quant_type, is_input_bsnh, is_output_bsnh); + input_sequence_length, cache_sequence_length, num_heads, head_size, bit_width, quant_type, is_input_bsnh); return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 6bf7c5b01e4d3..327e031414299 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -237,7 +237,7 @@ __global__ void UnpackRoPEAppend( constexpr float kInt4Max = 7.0f; const float* scale_buffer = (head_type == KEY) ? k_scale : v_scale; uint32_t packed = 0; - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < elements_per_thread / 2; ++i) { // Elements are paired as (0,1), (2,3), etc. into single bytes. float s0 = per_channel ? scale_buffer[n * head_size + h + i * 2] : scale_buffer[0]; float s1 = per_channel ? 
scale_buffer[n * head_size + h + i * 2 + 1] : scale_buffer[0]; diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 6c235f95aabcf..a97a09a040d5d 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -11,6 +11,7 @@ #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "core/providers/cuda/cuda_type_conversion.h" using namespace onnxruntime::cuda; @@ -116,8 +117,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { contribop_parameters.is_output_bnsh = false; } - typedef typename ToCudaType::MappedType CudaT; - // Check if this is Group Query Attention (GQA) const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads; @@ -196,6 +195,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { gqa_parameters.num_splits = 1; // Construct GroupQueryAttentionData + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; onnxruntime::contrib::cuda::GroupQueryAttentionData gqa_data; // Scratch buffers for flash/memory efficient attention @@ -480,152 +480,154 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = GetCublasHandle(context); return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data); - } + device_prop, cublas, context->GetComputeStream(), gqa_parameters, + gqa_data); + } else { // MHA path (kv_num_heads == q_num_heads) + typedef typename ToCudaType::MappedType CudaT; + contribop_parameters.batch_size = parameters.batch_size; + contribop_parameters.sequence_length = parameters.q_sequence_length; + contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; + contribop_parameters.past_sequence_length = parameters.past_sequence_length; + contribop_parameters.total_sequence_length = parameters.total_sequence_length; + // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size) + contribop_parameters.max_sequence_length = parameters.total_sequence_length; + contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V + contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; + contribop_parameters.head_size = parameters.head_size; + contribop_parameters.v_head_size = parameters.v_head_size; + contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; + contribop_parameters.num_heads = parameters.q_num_heads; + contribop_parameters.rotary_dim = 0; + contribop_parameters.num_splits = 1; + contribop_parameters.beam_width = 1; + contribop_parameters.is_unidirectional = parameters.is_causal; + contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer + contribop_parameters.is_packed_qkv = false; + contribop_parameters.do_rotary = false; + + // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask + // So mask_type should always be MASK_NONE since we don't have a separate padding mask input + contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; + + // Determine broadcast flags for attention_bias (if it exists) + // Note: The new Attention op uses attn_mask as attention_bias + // The 
attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) + // attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions). + if (attn_mask != nullptr) { + // TODO(titaiwang, xadupre): attn_mask bool is not supported yet + if (attn_mask->IsDataType()) { + ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA)."); + } - // MHA path (kv_num_heads == q_num_heads) - contribop_parameters.batch_size = parameters.batch_size; - contribop_parameters.sequence_length = parameters.q_sequence_length; - contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; - contribop_parameters.past_sequence_length = parameters.past_sequence_length; - contribop_parameters.total_sequence_length = parameters.total_sequence_length; - // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size) - contribop_parameters.max_sequence_length = parameters.total_sequence_length; - contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V - contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; - contribop_parameters.head_size = parameters.head_size; - contribop_parameters.v_head_size = parameters.v_head_size; - contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; - contribop_parameters.num_heads = parameters.q_num_heads; - contribop_parameters.rotary_dim = 0; - contribop_parameters.num_splits = 1; - contribop_parameters.beam_width = 1; - contribop_parameters.is_unidirectional = parameters.is_causal; - contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer - contribop_parameters.is_packed_qkv = false; - contribop_parameters.do_rotary = false; - - // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask - // So mask_type should always be MASK_NONE since we don't have a separate padding mask input - contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; - - // Determine broadcast flags for attention_bias (if it exists) - // Note: The new Attention op uses attn_mask as attention_bias - // The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) - // attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions). 
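// [Editor's note] Standalone sketch (not part of the patch) of the attn_mask broadcast rules
// described in this hunk: the trailing two dims are always (q_sequence_length,
// total_sequence_length), and a missing or size-1 leading dim broadcasts. The helper name and
// struct are illustrative only.
#include <cstdint>
#include <vector>

struct AttnBiasBroadcast {
  bool dim0;  // broadcast over batch
  bool dim1;  // broadcast over heads
};

inline AttnBiasBroadcast GetAttnBiasBroadcast(const std::vector<int64_t>& dims) {
  if (dims.size() == 2) return {true, true};           // (S_q, S_total): broadcast both
  if (dims.size() == 3) return {dims[0] == 1, true};   // (X, S_q, S_total): heads always broadcast
  return {dims[0] == 1, dims[1] == 1};                 // (B, H, S_q, S_total)
}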
- if (attn_mask != nullptr) { - // TODO(titaiwang, xadupre): attn_mask bool is not supported yet - if (attn_mask->IsDataType()) { - ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA)."); + size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); + auto attn_mask_dims = attn_mask->Shape().GetDims(); + // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting + // For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting + // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1 + + if (attn_mask_dims_size == 2) { + // 2D mask: both dimensions need broadcasting + contribop_parameters.broadcast_attn_bias_dim_0 = true; + contribop_parameters.broadcast_attn_bias_dim_1 = true; + } else if (attn_mask_dims_size == 3) { + // 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts + contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; + contribop_parameters.broadcast_attn_bias_dim_1 = true; + } else { + // 4D mask: check both dim 0 and dim 1 explicitly + contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; + contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; + } + } else { + contribop_parameters.broadcast_attn_bias_dim_0 = false; + contribop_parameters.broadcast_attn_bias_dim_1 = false; } - size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); - auto attn_mask_dims = attn_mask->Shape().GetDims(); - // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting - // For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting - // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1 - - if (attn_mask_dims_size == 2) { - // 2D mask: both dimensions need broadcasting - contribop_parameters.broadcast_attn_bias_dim_0 = true; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else if (attn_mask_dims_size == 3) { - // 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts - contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else { - // 4D mask: check both dim 0 and dim 1 explicitly - contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; - contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; + contribop_parameters.mask_filter_value = -10000.0f; + contribop_parameters.scale = parameters.scale; + contribop_parameters.use_tf32 = UseTF32(); + // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now + if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && + qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { + ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA)."); + } + // TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet + if (parameters.softcap != 0.0f) { + ORT_THROW("softcap is not supported yet in Attention op (CUDA)."); + } + if (parameters.softmax_precision != 0) { + ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA)."); } - } else { - contribop_parameters.broadcast_attn_bias_dim_0 = false; - contribop_parameters.broadcast_attn_bias_dim_1 = false; - } - contribop_parameters.mask_filter_value = -10000.0f; - contribop_parameters.scale = parameters.scale; - contribop_parameters.use_tf32 = 
UseTF32(); - // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now - if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && - qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { - ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA)."); - } - // TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet - if (parameters.softcap != 0.0f) { - ORT_THROW("softcap is not supported yet in Attention op (CUDA)."); - } - if (parameters.softmax_precision != 0) { - ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA)."); - } + // Construct AttentionData to pass to QkvToContext + onnxruntime::contrib::cuda::AttentionData data; - // Construct AttentionData to pass to QkvToContext - onnxruntime::contrib::cuda::AttentionData data; - - // Set input pointers - data.query = reinterpret_cast(Q->Data()); - data.key = reinterpret_cast(K->Data()); - data.value = reinterpret_cast(V->Data()); - data.mask_index = nullptr; // New Attention op doesn't have key_padding_mask - data.mask_index_dims = gsl::span(); - data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - - // Set output pointers - data.output = reinterpret_cast(Y->MutableData()); - data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); - data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast(present_value->MutableData()); - if (nullptr != output_qk) { - data.output_qk = reinterpret_cast(output_qk->MutableData()); - } + // Set input pointers + data.query = reinterpret_cast(Q->Data()); + data.key = reinterpret_cast(K->Data()); + data.value = reinterpret_cast(V->Data()); + data.mask_index = nullptr; // New Attention op doesn't have key_padding_mask + data.mask_index_dims = gsl::span(); + data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); + + // Set output pointers + data.output = reinterpret_cast(Y->MutableData()); + data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); + data.present_value = (present_value == nullptr) ? 
nullptr : reinterpret_cast(present_value->MutableData()); + if (nullptr != output_qk) { + data.output_qk = reinterpret_cast(output_qk->MutableData()); + } - // Set additional fields - data.bias = nullptr; // New Attention op doesn't have bias - if (nullptr != attn_mask) { - data.attention_bias = reinterpret_cast(attn_mask->Data()); + // Set additional fields + data.bias = nullptr; // New Attention op doesn't have bias + if (nullptr != attn_mask) { + data.attention_bias = reinterpret_cast(attn_mask->Data()); + } + data.qkv_format = contribop_parameters.qkv_format; + + // For now, set flags to false and let QkvToContext use the unfused path + data.use_flash_attention = false; + data.use_memory_efficient_attention = false; + data.fused_runner = nullptr; + data.fused_cross_attention_kernel = nullptr; + data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; + + // Allocate workspace for Q, K, V processing and scratch buffer + const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); + size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( + sizeof(T), + contribop_parameters.batch_size, + contribop_parameters.num_heads, + contribop_parameters.head_size, + contribop_parameters.v_head_size, + contribop_parameters.sequence_length, + contribop_parameters.kv_sequence_length, + contribop_parameters.total_sequence_length, + nullptr, // fused_runner + false, // use_flash_attention + false, // use_lean_attention + false, // use_fused_cross_attention + false, // use_memory_efficient_attention + false, // use_cudnn_flash_attention + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + // Call QkvToContext to perform the attention computation + auto& device_prop = GetDeviceProp(); + cublasHandle_t cublas = GetCublasHandle(context); + cudnnHandle_t cudnn = GetCudnnHandle(context); + + // QkvToContext takes two template parameters: T for computation type, QK for output_qk type + // For now, both are the same type (CudaT) + + return onnxruntime::contrib::cuda::QkvToContext( + device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); } - data.qkv_format = contribop_parameters.qkv_format; - - // For now, set flags to false and let QkvToContext use the unfused path - data.use_flash_attention = false; - data.use_memory_efficient_attention = false; - data.fused_runner = nullptr; - data.fused_cross_attention_kernel = nullptr; - data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; - - // Allocate workspace for Q, K, V processing and scratch buffer - const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); - size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( - sizeof(T), - contribop_parameters.batch_size, - contribop_parameters.num_heads, - contribop_parameters.head_size, - contribop_parameters.v_head_size, - contribop_parameters.sequence_length, - contribop_parameters.kv_sequence_length, - contribop_parameters.total_sequence_length, - nullptr, // fused_runner - false, // use_flash_attention - false, // use_lean_attention - false, // use_fused_cross_attention - false, // use_memory_efficient_attention - false, // use_cudnn_flash_attention - no_qkv_workspace); - auto work_space = 
GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.workspace_bytes = workspace_bytes; - - // Call QkvToContext to perform the attention computation - auto& device_prop = GetDeviceProp(); - cublasHandle_t cublas = GetCublasHandle(context); - cudnnHandle_t cudnn = GetCudnnHandle(context); - - // QkvToContext takes two template parameters: T for computation type, QK for output_qk type - // For now, both are the same type (CudaT) - return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); } } // namespace cuda } // namespace onnxruntime From 3d507bd07d6308f6d34fc2680dee78ab7cd28ef3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 13 Feb 2026 22:31:47 +0000 Subject: [PATCH 10/11] refine --- .../cuda/bert/group_query_attention_impl.cu | 279 +++++++++--------- .../cuda/bert/group_query_attention_impl.h | 12 - .../cuda/bert/group_query_attention_qkv.cuh | 7 +- 3 files changed, 150 insertions(+), 148 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 99e744a345daa..cde185de907d4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -77,31 +77,34 @@ namespace cuda { // 3. Ensuring synchronization between past and present KV caches when necessary. // 4. Launching the UnpackRoPEQuantizeAppend kernel to unpack, apply RoPE, and update caches. // 5. Returning strict Q, K, V pointers ready for the core attention operation. -template +template Status PrepareQKV( cudaStream_t stream, const int max_threads_per_block, const GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, - const CudaT*& q) { + GroupQueryAttentionData& data, + const T*& q) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - CudaT* q_out = reinterpret_cast(data.qkv_buffer); + T* q_out = reinterpret_cast(data.qkv_buffer); if (!parameters.is_packed_qkv && !parameters.do_rotary) { q_out = nullptr; } - CudaU* k = reinterpret_cast(data.present_key); - CudaU* v = reinterpret_cast(data.present_value); + U* k = reinterpret_cast(data.present_key); + U* v = reinterpret_cast(data.present_value); int max_cache_length = parameters.seqlen_present_kv_cache; if (!parameters.past_present_share_buffer) { - size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(CudaU); + size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(U); CUDA_CALL_THROW(cudaMemsetAsync(data.present_key, 0, kv_buffer_size, stream)); CUDA_CALL_THROW(cudaMemsetAsync(data.present_value, 0, kv_buffer_size, stream)); } @@ -111,8 +114,8 @@ Status PrepareQKV( // Copy past KV to present KV if needed if (!parameters.past_present_share_buffer && data.past_key != nullptr && parameters.seqlen_past_kv_cache > 0) { - size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaU); - size_t dst_pitch = (size_t)max_cache_length * head_size * sizeof(CudaU); + size_t src_pitch = 
(size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(U); + size_t dst_pitch = (size_t)max_cache_length * head_size * sizeof(U); size_t width = src_pitch; size_t height = (size_t)batch_size * kv_num_heads; CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, @@ -121,23 +124,23 @@ Status PrepareQKV( cudaMemcpyDeviceToDevice, stream)); } - ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( - parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), q_out, k, v, data.k_scale, data.v_scale, num_heads, kv_num_heads, head_size, sequence_length, batch_size, max_cache_length, data.past_seq_lens, - reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, is_cache_bnsh, parameters.k_quant_type, stream, max_threads_per_block))); if (q_out != nullptr) { - q = reinterpret_cast(q_out); + q = reinterpret_cast(q_out); } else { - q = reinterpret_cast(data.query); + q = reinterpret_cast(data.query); } return Status::OK(); @@ -146,16 +149,16 @@ Status PrepareQKV( ////////// Auxiliary Kernels for KV prep // Concat new to past in present. Supports past BSNH or past BNSH -template +template Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, const void* new_key, const void* new_value, cudaStream_t stream, const int max_threads_per_block, const bool past_only = false, - const CudaT* cos_cache = nullptr, - const CudaT* sin_cache = nullptr, + const T* cos_cache = nullptr, + const T* sin_cache = nullptr, const int rotary_dim = 0, const int64_t* position_ids = nullptr, const bool interleaved = false) { @@ -178,12 +181,12 @@ Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, is_bsnh, data.past_seq_lens, data.total_seq_lens, - reinterpret_cast(data.past_key), - reinterpret_cast(data.past_value), - reinterpret_cast(new_key), - reinterpret_cast(new_value), - reinterpret_cast(data.present_key), - reinterpret_cast(data.present_value), + reinterpret_cast(data.past_key), + reinterpret_cast(data.past_value), + reinterpret_cast(new_key), + reinterpret_cast(new_value), + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), stream, max_threads_per_block, past_only, @@ -195,9 +198,9 @@ Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, } // Concat new to kv buffer in place -template +template Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, const void* new_key, const void* new_value, bool is_new_kv_bnsh_format, @@ -216,10 +219,10 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, data.past_seq_lens, data.total_seq_lens, parameters.sequence_length, - reinterpret_cast(new_key), - reinterpret_cast(new_value), - 
reinterpret_cast(data.present_key), - reinterpret_cast(data.present_value), + reinterpret_cast(new_key), + reinterpret_cast(new_value), + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, @@ -252,9 +255,9 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, // - q_num_heads is divisible by kv_num_heads // - H * q_num_heads <= max_threads_per_block (use UngroupLarge otherwise) // ============================================================================ -template -__global__ void Ungroup(const CudaT* kv_in, - CudaT* kv_out, +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, const int in_seqlen, const int kv_num_heads, const bool is_bsnh) { @@ -295,9 +298,9 @@ __global__ void Ungroup(const CudaT* kv_in, // blockIdx.y = s (sequence position) // blockIdx.z = b (batch index) // ============================================================================ -template -__global__ void UngroupLarge(const CudaT* kv_in, - CudaT* kv_out, +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, const int H, const int in_seqlen, const int q_num_heads, @@ -328,7 +331,7 @@ __global__ void UngroupLarge(const CudaT* kv_in, } // Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. -template +template Status LaunchUngroup(const GroupQueryAttentionParameters& parameters, float2* k_buff, float2* v_buff, const float2* k_og, const float2* v_og, @@ -404,8 +407,8 @@ Status LaunchUngroup(const GroupQueryAttentionParameters& parameters, // One thread per element in packed_qkv. Thread determines which of Q/K/V // the element belongs to based on the offset within the hidden dimension. // ============================================================================ -template -__global__ void UnpackQKV(const CudaT* packed_qkv, CudaT* unpacked_q, CudaT* unpacked_k, CudaT* unpacked_v, const int num_heads, +template +__global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size) { const int tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -419,7 +422,7 @@ __global__ void UnpackQKV(const CudaT* packed_qkv, CudaT* unpacked_q, CudaT* unp int offset = tid % d; if (output_bnsh) { // output BNSH int head_count = kv_num_heads; - CudaT* unpacked = nullptr; + T* unpacked = nullptr; if (offset < q_hidden) { unpacked = unpacked_q; head_count = num_heads; @@ -466,13 +469,13 @@ __global__ void UnpackQKV(const CudaT* packed_qkv, CudaT* unpacked_q, CudaT* unp } // Unpack packed qkv -template -Status LaunchUnpackQKV(const CudaT* packed_qkv, CudaT* unpacked_q, CudaT* unpacked_k, CudaT* unpacked_v, const int num_heads, +template +Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block) { const int threads = max_threads_per_block; const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads; - UnpackQKV<<>>( + UnpackQKV<<>>( packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads, head_size, sequence_length, batch_size); return CUDA_CALL(cudaGetLastError()); } @@ -552,13 +555,16 @@ Status LaunchGetSequenceLengths( ////////// Kernels (supports 
right padding but not left padding) // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. // Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. -template +template Status ExtremeDecoding( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + ORT_GQA_TRACE("ExtremeDecoding"); const int batch_size = parameters.batch_size; @@ -575,20 +581,20 @@ Status ExtremeDecoding( // Ultimate Fused Preprocessing: Unpack, RoPE Q, RoPE K, Quantize K/V, Append K/V // This replaces all manual steps (Rotate Q, Rotate K, Quantize, StridedCopy) - CudaT* q_rot_ptr = reinterpret_cast(data.qkv_buffer); - const CudaT* q_input_for_xqa = q_rot_ptr; + T* q_rot_ptr = reinterpret_cast(data.qkv_buffer); + const T* q_input_for_xqa = q_rot_ptr; if (q_rot_ptr == nullptr) { - q_input_for_xqa = reinterpret_cast(data.query); + q_input_for_xqa = reinterpret_cast(data.query); } - ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( - parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), q_rot_ptr, // unpacked_q (can be null if !do_rotary) - reinterpret_cast(data.present_key), - reinterpret_cast(data.present_value), + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), data.k_scale, data.v_scale, num_heads, @@ -598,8 +604,8 @@ Status ExtremeDecoding( batch_size, parameters.seqlen_present_kv_cache, // max_seqlen (capacity) data.past_seq_lens, - reinterpret_cast(data.cos_cache), - reinterpret_cast(data.sin_cache), + reinterpret_cast(data.cos_cache), + reinterpret_cast(data.sin_cache), parameters.do_rotary ? parameters.rotary_dim : 0, data.position_ids, parameters.rotary_interleaved, @@ -612,10 +618,10 @@ Status ExtremeDecoding( void* xqa_workspace = data.xqa_buffer; size_t xqa_workspace_size = data.xqa_buffer_bytes; - constexpr bool is_fp8 = std::is_same::value; + constexpr bool is_fp8 = std::is_same::value; using onnxruntime::contrib::cuda::XqaQuantType; // 5. Launch XQA - Status status = onnxruntime::contrib::cuda::LaunchXQAKernel( + Status status = onnxruntime::contrib::cuda::LaunchXQAKernel( device_prop, stream, q_input_for_xqa, @@ -648,13 +654,15 @@ Status ExtremeDecoding( // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. // Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. 
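// [Editor's note] In the packed-QKV case handled by FlashDecoding below, each token stores its
// Q, K and V heads contiguously, so K and V start at fixed element offsets from Q
// (num_heads * head_size and kv_num_heads * head_size, matching the key_offset/value_offset
// computation further down). A minimal standalone sketch; half precision and the pointer names
// are assumptions for illustration, and the real path relies on the flash kernel to handle the
// per-token stride.
#include <cuda_fp16.h>

inline void SplitPackedQkvToken(const half* packed_token, int num_heads, int kv_num_heads,
                                int head_size, const half*& q, const half*& k, const half*& v) {
  q = packed_token;                                        // num_heads * head_size query values
  k = q + static_cast<size_t>(num_heads) * head_size;      // key heads follow the query heads
  v = k + static_cast<size_t>(kv_num_heads) * head_size;   // value heads follow the key heads
}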
-template +template Status FlashDecoding( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); assert(!parameters.is_first_prompt && parameters.past_present_share_buffer); ORT_GQA_TRACE("FlashDecoding"); @@ -667,29 +675,29 @@ Status FlashDecoding( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value; + bool is_bf16 = std::is_same::value; - void* query = reinterpret_cast(const_cast(data.query)); + void* query = reinterpret_cast(const_cast(data.query)); void* key; void* value; if (!parameters.is_packed_qkv) { - key = reinterpret_cast(const_cast(data.key)); - value = reinterpret_cast(const_cast(data.value)); + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); } else { const size_t key_offset = static_cast(num_heads * head_size); const size_t value_offset = static_cast(kv_num_heads * head_size); - key = reinterpret_cast(query) + key_offset; - value = reinterpret_cast(key) + value_offset; + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; } void* seqlens_k = reinterpret_cast(data.past_seq_lens); void* present_key = data.present_key; void* present_value = data.present_value; - void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); - void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); - void* head_sink = reinterpret_cast(const_cast(data.head_sink)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + void* head_sink = reinterpret_cast(const_cast(data.head_sink)); bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; @@ -712,13 +720,16 @@ Status FlashDecoding( // Use extra kernel(s) for unpacking, rotary and kv append. // Flash attention is used for attention only. 
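// [Editor's note] The prep step referenced above applies rotary position embedding (RoPE)
// before attention, which is why the FlashAttention path below passes nullptr for
// cos_cache/sin_cache. For reference, a textbook rotation of one (x0, x1) pair given the
// cached cos/sin for its position; this is a generic sketch, not the kernel in this patch,
// which additionally handles interleaved vs. blocked layouts and vectorized loads.
__device__ inline void RotatePair(float& x0, float& x1, float cos_v, float sin_v) {
  const float r0 = x0 * cos_v - x1 * sin_v;
  const float r1 = x0 * sin_v + x1 * cos_v;
  x0 = r0;
  x1 = r1;
}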
-template +template Status FlashAttention( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; @@ -729,21 +740,21 @@ Status FlashAttention( AttentionQkvFormat past_kv_format = parameters.past_kv_format; bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value; + bool is_bf16 = std::is_same::value; DUMP_TENSOR_INIT(); - const CudaT* q_prep = nullptr; - ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); + const T* q_prep = nullptr; + ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); - void* query = const_cast(q_prep); + void* query = const_cast(q_prep); void* present_key = data.present_key; void* present_value = data.present_value; // Disable internal RoPE in Flash Attention (pass nullptr) void* cos_cache = nullptr; void* sin_cache = nullptr; - void* head_sink = reinterpret_cast(const_cast(data.head_sink)); + void* head_sink = reinterpret_cast(const_cast(data.head_sink)); // We have already appended (and quantized if needed) the new tokens into present_key/value. // Pass nullptr for new_k/new_v to disable flash attention kernel's internal Append_KV logic. @@ -778,13 +789,16 @@ Status FlashAttention( // Fallback path for decoding quantized kv cache, when XQA is not usable (due to softcap, window, etc.) // We dequantize the cache and run standard Flash Attention. -template +template Status DequantizeFlashAttentionFallback( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + assert(!parameters.is_first_prompt); // Only support first prompt for this function. assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE); @@ -793,25 +807,25 @@ Status DequantizeFlashAttentionFallback( // We need to dequantize the entire KV cache (present_key/value) into a float/half buffer (data.qkv_buffer). // Layout in qkv_buffer: [Q (rotated)] [K_dequantized] [V_dequantized] - CudaT* q_rot = reinterpret_cast(data.qkv_buffer); + T* q_rot = reinterpret_cast(data.qkv_buffer); size_t q_elements = static_cast(parameters.batch_size) * parameters.sequence_length * parameters.num_heads * parameters.head_size; size_t k_elements = static_cast(parameters.batch_size) * parameters.seqlen_present_kv_cache * parameters.kv_num_heads * parameters.head_size; - CudaT* k_dequant = q_rot + q_elements; - CudaT* v_dequant = k_dequant + k_elements; + T* k_dequant = q_rot + q_elements; + T* v_dequant = k_dequant + k_elements; // Step 1: Update Quantized Cache // We can use LaunchUnpackRoPEQuantizeAppend to unpack new QKV, apply RoPE, and append to quantized cache. // This will also put rotated Q into q_rot. - ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( - parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, - parameters.is_packed_qkv ? 
nullptr : reinterpret_cast(data.query), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), - q_rot, reinterpret_cast(data.present_key), reinterpret_cast(data.present_value), + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + q_rot, reinterpret_cast(data.present_key), reinterpret_cast(data.present_value), data.k_scale, data.v_scale, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, parameters.batch_size, parameters.seqlen_present_kv_cache, data.past_seq_lens, - reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH), parameters.k_quant_type, @@ -821,19 +835,19 @@ Status DequantizeFlashAttentionFallback( // We now have the updated quantized cache in data.present_key/value. We need to dequantize it to k_dequant/v_dequant. bool is_bsnh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size, parameters.kv_cache_bit_width, parameters.k_quant_type, is_bsnh))); - ORT_RETURN_IF_ERROR((LaunchDequantizeKV( - stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size, parameters.kv_cache_bit_width, parameters.v_quant_type, is_bsnh))); // Step 3: Run Flash Attention on dequantized k/v bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value; + bool is_bf16 = std::is_same::value; // Use the total_seq_lens here since k_dequant/v_dequant has both past and new tokens. 
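// [Editor's note] Aside: the LaunchDequantizeKV calls above recover float/half values from the
// quantized cache by multiplying with the stored scale (the inverse of the inv_scale applied at
// quantization time). A minimal per-element sketch for the int8 case; the function name is
// illustrative, and the real kernel also covers fp8/int4 and both BSNH/BNSH layouts.
#include <cstdint>
#include <cuda_fp16.h>

__device__ inline half DequantizeInt8(int8_t q, float scale) {
  return __float2half(static_cast<float>(q) * scale);
}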
void* seqlens_k_ptr = const_cast(reinterpret_cast(data.total_seq_lens)); @@ -842,7 +856,7 @@ Status DequantizeFlashAttentionFallback( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, q_rot, k_dequant, v_dequant, nullptr /*new K*/, nullptr /*new V*/, data.output, reinterpret_cast(data.softmax_lse), seqlens_k_ptr, nullptr /*cos_cache*/, nullptr /*sin_cache*/, - /*cache_batch_idx*/ nullptr, /*leftpad_k*/ nullptr, reinterpret_cast(const_cast(data.head_sink)), /*block_table*/ nullptr, + /*cache_batch_idx*/ nullptr, /*leftpad_k*/ nullptr, reinterpret_cast(const_cast(data.head_sink)), /*block_table*/ nullptr, parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, parameters.seqlen_present_kv_cache, parameters.sequence_length, 0 /*rotary_dim = 0 as it is already rotated*/, scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, is_bsnh, parameters.num_splits, @@ -854,13 +868,15 @@ Status DequantizeFlashAttentionFallback( } // Use Flash Attention for float key and value, then quantize key/value (int8/fp8/int4) to save to k/v cache. -template +template Status FlashAttentionAndQuantizeKV( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); assert(parameters.is_first_prompt); // Only support first prompt for this function. assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE); @@ -878,22 +894,22 @@ Status FlashAttentionAndQuantizeKV( size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - CudaT* q_final = reinterpret_cast(data.qkv_buffer); + T* q_final = reinterpret_cast(data.qkv_buffer); // For FlashAttentionAndQuantizeKV, we need float K and V for attention. // We'll write them to qkv_buffer. - CudaT* k_final = q_final + q_elements; - CudaT* v_final = k_final + k_elements; - - ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( - parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), - parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + T* k_final = q_final + q_elements; + T* v_final = k_final + k_elements; + + ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), q_final, k_final, v_final, nullptr, nullptr, num_heads, kv_num_heads, head_size, sequence_length, batch_size, sequence_length, data.past_seq_lens, - reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, false, // BSNH for scratch KVQuantizationType::NONE, @@ -901,7 +917,7 @@ Status FlashAttentionAndQuantizeKV( // 2. 
Run Float Flash Attention bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value; + bool is_bf16 = std::is_same::value; int local_window_size = parameters.local_window_size > 0 ? parameters.local_window_size - 1 : -1; @@ -917,15 +933,15 @@ Status FlashAttentionAndQuantizeKV( local_window_size)); if (parameters.k_quant_type != KVQuantizationType::NONE) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, head_size, parameters.kv_cache_bit_width, parameters.k_quant_type, true))); } if (parameters.v_quant_type != KVQuantizationType::NONE) { - ORT_RETURN_IF_ERROR((LaunchQuantizeKV( - stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, head_size, parameters.kv_cache_bit_width, parameters.v_quant_type, true))); } @@ -935,13 +951,16 @@ Status FlashAttentionAndQuantizeKV( #endif #if USE_MEMORY_EFFICIENT_ATTENTION -template +template Status EfficientAttention( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; @@ -953,8 +972,8 @@ Status EfficientAttention( ORT_GQA_TRACE("EfficientAttention"); - const CudaT* q_prep = nullptr; - ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); + const T* q_prep = nullptr; + ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); const void* query = reinterpret_cast(q_prep); const void* key; @@ -971,16 +990,16 @@ Status EfficientAttention( const float2* k_og = reinterpret_cast(data.present_key); const float2* v_og = reinterpret_cast(data.present_value); - ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, - present_sequence_length, is_kv_bsnh, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, + present_sequence_length, is_kv_bsnh, stream, max_threads_per_block)); key = reinterpret_cast(data.k); value = reinterpret_cast(data.v); } MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; - p.is_bf16 = std::is_same::value; - p.is_half = !p.is_bf16 && (sizeof(CudaT) == 2); + p.is_bf16 = std::is_same::value; + p.is_half = !p.is_bf16 && (sizeof(T) == 2); p.batch_size = batch_size; p.num_heads = num_heads; p.sequence_length = sequence_length; @@ -1000,7 +1019,7 @@ Status EfficientAttention( p.attn_bias = nullptr; p.is_kv_bsnh = is_kv_bsnh; p.output = data.output; - p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(CudaT) == sizeof(float)) + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, 
sizeof(T) == sizeof(float)) ? data.fmha_buffer : nullptr; p.stream = stream; @@ -1015,13 +1034,13 @@ Status EfficientAttention( ////////// API Functions -template +template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& /*cublas*/, Stream* ort_stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data) { + GroupQueryAttentionData& data) { auto stream = static_cast(ort_stream->GetHandle()); const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; if (data.use_xqa) { @@ -1057,7 +1076,6 @@ Status QkvToContext( template struct GroupQueryAttentionData; template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>; -template struct GroupQueryAttentionData; template struct GroupQueryAttentionData; template Status QkvToContext( @@ -1074,13 +1092,6 @@ template Status QkvToContext<__nv_bfloat16, __nv_bfloat16>( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>& data); -template Status QkvToContext( - const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, - contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data); - template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 8cd4b44b9832e..78b061837e402 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -133,18 +133,6 @@ Status LaunchGetSequenceLengths( cudaStream_t stream, const int max_threads_per_block); -template -Status LaunchUnpackRoPEAppend( - const T* packed_qkv, const T* query, const T* key, const T* value, - T* unpacked_q, void* k_cache, void* v_cache, - const float* k_scale, const float* v_scale, - const int num_heads, const int kv_num_heads, const int head_size, - const int sequence_length, const int batch_size, const int max_seqlen, - const int* past_seq_lens, const T* cos_cache, const T* sin_cache, - const int rotary_dim, const int64_t* position_ids, const bool interleaved, - const bool is_cache_bnsh, const KVQuantizationType k_quant_type, - const int bit_width, cudaStream_t stream, const int max_threads_per_block); - } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 327e031414299..02368dad4d193 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -32,7 +32,7 @@ namespace cuda { // 4. Writes the rotated Q back to global memory (unpacked_q) for the subsequent attention kernel. // // Template Parameters: -// - T: The floating point type for query (half or BFloat16). +// - T: The floating point type for query (half or __nv_bfloat16). // - U: The cache element type (T for no quant, int8_t for INT8, uint8_t for INT4, __nv_fp8_e4m3 for FP8). // - BIT_WIDTH: The bit width for KV cache quantization (16=none, 8=Int8/FP8, 4=Int4). // - MAX_HEAD_SIZE: Maximum supported head size, used for shared memory allocation. @@ -291,7 +291,7 @@ Status DispatchUnpackRoPEAppendHeadSize( // Public entry point to launch the Unpack+RoPE+Append kernel. 
// Handles parameter validation, grid/block sizing, and type-based dispatching. // Template parameters: -// - T: Query/Key/Value floating point type (half or BFloat16) +// - T: Query/Key/Value floating point type (half or __nv_bfloat16) // - U: Cache element type (T for no quant, int8_t for INT8, uint8_t for INT4, __nv_fp8_e4m3 for FP8) template Status LaunchUnpackRoPEAppend( @@ -304,6 +304,9 @@ Status LaunchUnpackRoPEAppend( const int rotary_dim, const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const KVQuantizationType k_quant_type, cudaStream_t stream, const int max_threads_per_block) { + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + constexpr int elements_per_vector = sizeof(float4) / sizeof(T); if (head_size % elements_per_vector != 0) { From 2a25780260c23e579e4d3f04a04a5cff44f7babe Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 14 Feb 2026 04:11:54 +0000 Subject: [PATCH 11/11] fix build --- .../contrib_ops/cuda/bert/group_query_attention_impl.cu | 1 - .../contrib_ops/cuda/bert/group_query_attention_qkv.cuh | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index cde185de907d4..961c80748d228 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -50,7 +50,6 @@ limitations under the License. #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_type_conversion.h" - #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 02368dad4d193..20f0144c335ee 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -12,6 +12,7 @@ #include "contrib_ops/cpu/bert/attention_common.h" #include "contrib_ops/cuda/bert/rotary_common.cuh" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" #include "core/providers/cuda/shared_inc/cuda_call.h" using namespace onnxruntime::cuda; @@ -304,8 +305,8 @@ Status LaunchUnpackRoPEAppend( const int rotary_dim, const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const KVQuantizationType k_quant_type, cudaStream_t stream, const int max_threads_per_block) { - static_assert(std::is_same::type>::value); - static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); + static_assert(std::is_same::type>::value); constexpr int elements_per_vector = sizeof(float4) / sizeof(T);
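// [Editor's note] The head_size % elements_per_vector check above exists because the kernel
// moves head elements in float4-sized vectors (8 half values, or 4 float values, per access).
// A standalone illustration of that access pattern, assuming 16-byte-aligned buffers; this is
// a generic sketch, not the UnpackRoPEAppend kernel itself.
#include <cuda_fp16.h>

__global__ void CopyHeadVectorized(const half* __restrict__ src, half* __restrict__ dst,
                                   int head_size) {
  constexpr int kElemsPerVec = sizeof(float4) / sizeof(half);  // 8 elements per 16-byte vector
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * kElemsPerVec < head_size) {
    // One thread moves 8 half values at once via a single float4 load/store.
    reinterpret_cast<float4*>(dst)[vec_idx] = reinterpret_cast<const float4*>(src)[vec_idx];
  }
}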