From 080db98920c05d1160e6e4ccad485e7e2dbbeea1 Mon Sep 17 00:00:00 2001 From: chraac Date: Sun, 1 Feb 2026 22:00:45 +0800 Subject: [PATCH 1/7] ggml-hexagon: enhance hvx_mad functions for improved performance and clarity --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 67 +++++++++++++++++++--- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index c1846374437..924097af4f9 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -222,7 +222,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } -// MAD: y (F32) += x (F16) * s (float) +// MAD: y (F32) += x (F16) * s (F32) static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; HVX_Vector * restrict ptr_y = (HVX_Vector *) y; @@ -259,6 +259,59 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict } } +// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32) +static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, + const void * restrict x0, + const void * restrict x1, + float s0, + float s1, + int n) { + const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0; + const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S0 = hvx_vec_splat_f16(s0); + HVX_Vector S1 = hvx_vec_splat_f16(s1); + + uint32_t i = 0; + #pragma unroll(2) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); + HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + + HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); + HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + + ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2])); + ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1])); + } + + if (nloe) { + HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); + HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + + HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); + HVX_Vector xs = xs_p_lo; + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; + ++i; + xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + } + + if (nloe) { + HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + } + } +} + #define FLASH_ATTN_BLOCK_SIZE 128 static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { @@ -415,6 +468,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + const HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); for (uint32_t ib = 0; ib < n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; @@ -461,7 +515,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const __fp16 * mp = m_base + ic; 
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16); HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); @@ -498,12 +551,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // 5. Accumulate V float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; + *(HVX_Vector *) p_arr = P; - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic2 + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + for (int j = 0; j < VLEN_FP32; j += 2) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + size_v_row_padded, p_arr[j], p_arr[j + 1], DV); } } From b022069260e785f6d4781b047ec56dd18e76adf6 Mon Sep 17 00:00:00 2001 From: chraac Date: Sun, 1 Feb 2026 22:16:50 +0800 Subject: [PATCH 2/7] wip --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 924097af4f9..a3797bdc1a9 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -468,7 +468,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; - const HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); + const bool is_q_fp32 = (q->type == HTP_TYPE_F32); for (uint32_t ib = 0; ib < n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; @@ -482,8 +482,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // Inner loop processing the block from VTCM uint32_t ic = 0; - const bool is_q_fp32 = (q->type == HTP_TYPE_F32); - // Process in blocks of 32 (VLEN_FP32) static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); HVX_Vector_x4 scores_x4; @@ -515,6 +513,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const __fp16 * mp = m_base + ic; HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; + HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16); HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); From fe1f3fbc2a1e04f2afca832c5f994e370cccbccc Mon Sep 17 00:00:00 2001 From: chraac Date: Mon, 2 Feb 2026 23:59:54 +0800 Subject: [PATCH 3/7] ggml-hexagon: optimize flash attention calculations with improved variable handling --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index a3797bdc1a9..441f619698b 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -409,6 +409,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const bool is_q_fp32 = (q->type == HTP_TYPE_F32); + const HVX_Vector logit_cap = hvx_vec_splat_f32(logit_softcap); + for (uint32_t ir = ir0; ir < ir1; ++ir) { const uint32_t iq3 = fastdiv(ir, 
&octx->src0_div21); const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); @@ -468,7 +471,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; - const bool is_q_fp32 = (q->type == HTP_TYPE_F32); + const HVX_Vector slope_vec = hvx_vec_splat_f32(slope); for (uint32_t ib = 0; ib < n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; @@ -504,7 +507,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // 2. Softcap if (logit_softcap != 0.0f) { scores = hvx_vec_tanh_f32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap)); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap); scores = Q6_Vsf_equals_Vqf32(scores); } @@ -518,7 +521,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); - HVX_Vector slope_vec = hvx_vec_splat_f32(slope); HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); scores = Q6_Vsf_equals_Vqf32(scores); From c2fe8a12bb168367470f7e7e2c90ec239efc23af Mon Sep 17 00:00:00 2001 From: chraac Date: Tue, 3 Feb 2026 00:14:51 +0800 Subject: [PATCH 4/7] ggml-hexagon: streamline flash attention operations by removing redundant checks for FP32 --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 441f619698b..af32c1e2c7b 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -470,9 +470,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } } - const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; - const HVX_Vector slope_vec = hvx_vec_splat_f32(slope); + uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + if (is_q_fp32) { + hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 + } + const HVX_Vector slope_vec = hvx_vec_splat_f32(slope); for (uint32_t ib = 0; ib < n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); @@ -495,11 +498,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in for (int j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic + j; const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; - if (is_q_fp32) { - hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); - } else { - hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); - } + hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); } HVX_Vector scores = *(HVX_Vector *) scores_arr; @@ -569,13 +568,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in for (; ic < current_block_size; ++ic) { float s_val; const uint8_t * k_ptr = k_base + ic * size_k_row_padded; - - if (is_q_fp32) { - hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } else { - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } - + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); if (logit_softcap != 0.0f) { s_val = logit_softcap * tanhf(s_val); } From 
367463cf0348fdceaef28da229852cc29fce30d5 Mon Sep 17 00:00:00 2001 From: chraac Date: Tue, 3 Feb 2026 00:17:33 +0800 Subject: [PATCH 5/7] wip --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 115 --------------------- 1 file changed, 115 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index af32c1e2c7b..fc7ecb90414 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,121 +17,6 @@ #include "htp-msg.h" #include "htp-ops.h" -static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements - return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x_hf = Q6_V_vand_QV(bmask, x_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); - hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, - const void * restrict y, - const void * restrict x0, - const void * restrict x1, - unsigned int n, - float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 - const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(2) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - HVX_VectorPair xy0_qf = 
Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x0_hf = Q6_V_vand_QV(bmask, x0_hf); - x1_hf = Q6_V_vand_QV(bmask, x1_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); - hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); -} - // Dot product of two F16 vectors, accumulating to float static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 From 5458f41b18338fef93a9b8eb052ad0c467cc5209 Mon Sep 17 00:00:00 2001 From: chraac Date: Tue, 3 Feb 2026 00:37:28 +0800 Subject: [PATCH 6/7] ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx2 by simplifying variable handling for unused elements --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index fc7ecb90414..4436b99db85 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -89,12 +89,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, } if (nloe) { - HVX_Vector y_hf = vy[i]; - - // Load x (fp16) and zero-out unused elements + // Load x (fp16) and zero-out unused y elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); - HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); From 7f477c43a42775474295d18da77773daa8af91d0 Mon Sep 17 00:00:00 2001 From: chraac Date: Tue, 3 Feb 2026 00:50:41 +0800 Subject: [PATCH 7/7] wip --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 4436b99db85..15b40f77f5f 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -405,7 +405,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); - scores = 
Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores); scores = Q6_Vsf_equals_Vqf32(scores); }
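
Note (added for reference, not part of the patch series): the two-row MAD helper introduced in PATCH 1/7, hvx_mad_f32_f16_aa_rx2, folds two V rows into a single pass over the F32 accumulator, which is why PATCH 1/7 also switches the inner loop to step j by 2. A minimal scalar sketch of the intended semantics follows; the function name mad_f32_f16_rx2_ref is illustrative only, and the real code uses aligned HVX qf32 intrinsics, padded rows, and converts s0/s1 to fp16 before the multiply.

    /* Scalar reference: y (F32) += x0 (F16) * s0 + x1 (F16) * s1 over n elements.
     * The HVX version does the fp16*fp16 products in qf32 and adds both partial
     * products before accumulating into y, so it touches y only once per vector. */
    static void mad_f32_f16_rx2_ref(float * y, const __fp16 * x0, const __fp16 * x1,
                                    float s0, float s1, int n) {
        for (int i = 0; i < n; ++i) {
            y[i] += (float) x0[i] * s0 + (float) x1[i] * s1;
        }
    }

    /* Usage mirrors the rewritten accumulation loop in flash_attn_ext_f16_thread,
     * consuming two attention weights per call:
     *   hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + size_v_row_padded,
     *                          p_arr[j], p_arr[j + 1], DV);
     */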