From cf53dc49558190d2db82bef9f7eac84ea7a52d76 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Tue, 3 Feb 2026 22:21:19 -0800 Subject: [PATCH 1/3] hexagon: add ARGSORT op --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 34 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 2 + ggml/src/ggml-hexagon/htp/argsort-ops.c | 281 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-msg.h | 17 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/hvx-base.h | 6 + ggml/src/ggml-hexagon/htp/hvx-copy.h | 2 - ggml/src/ggml-hexagon/htp/main.c | 47 ++++ 8 files changed, 372 insertions(+), 18 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/argsort-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 4f0a1620..2f92c761 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2111,6 +2111,21 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * dst = op; // indices + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (dst->type != GGML_TYPE_I32) { + return false; + } + + return true; +} + static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; @@ -2316,6 +2331,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * return n_bufs; } +static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_ARGSORT; + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { switch (t->op) { @@ -2564,6 +2590,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_ARGSORT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2968,6 +2998,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cpy(sess, op); break; + case GGML_OP_ARGSORT: + supp = ggml_hexagon_supported_argsort(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index e8ef2030..5922dcc8 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include_directories( ${HEXAGON_SDK_ROOT}/incs ${HEXAGON_SDK_ROOT}/incs/stddef + ${CMAKE_CURRENT_SOURCE_DIR}/../../../include ${CMAKE_CURRENT_SOURCE_DIR}/../.. ${CMAKE_CURRENT_SOURCE_DIR}/.. 
${CMAKE_CURRENT_SOURCE_DIR} @@ -28,6 +29,7 @@ add_library(${HTP_LIB} SHARED set-rows-ops.c get-rows-ops.c cpy-ops.c + argsort-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c new file mode 100644 index 00000000..a4cee980 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "ggml.h" + +#include "hvx-utils.h" +#include "hex-dma.h" + +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +struct htp_argsort_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; +}; + +static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y) +{ + const HVX_Vector one = Q6_V_vsplat_R(1); + const HVX_Vector zero = Q6_V_vzero(); + + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y); + HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + return hvx_vec_get_i32(sum) == 32; +} + +// Sorts values and mirrors swaps to indices. +static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + while (i <= j) { + // Vectorized scan for i + while (i <= j) { + // Check if we have at least one full vector + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(pivot_vec, vals_vec)) { + // If all elements are < pivot, we can skip this whole block + i += 32; + continue; + } + } + + // Scalar fallback / cleanup + if (values[i] < pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j + while (i <= j) { + if (j - 32 >= i) { + // Load 32 elements ending at j. + // Since we want `values[j] > pivot`, let's load from j-31 to j. 
+ HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(vals_vec, pivot_vec)) { + j -= 32; + continue; + } + } + + if (values[j] > pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_asc(values, indices, left, j); + if (i < right) quicksort_values_indices_asc(values, indices, i, right); +} + +static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + + while (i <= j) { + // Vectorized scan for i (values[i] > pivot) + while (i <= j) { + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(vals_vec, pivot_vec)) { + i += 32; + continue; + } + } + + if (values[i] > pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j (values[j] < pivot) + while (i <= j) { + if (j - 32 >= i) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(pivot_vec, vals_vec)) { + j -= 32; + continue; + } + } + + if (values[j] < pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_desc(values, indices, left, j); + if (i < right) quicksort_values_indices_desc(values, indices, i, right); +} + +static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { + struct htp_argsort_context * actx = (struct htp_argsort_context *)data; + struct htp_ops_context * octx = actx->octx; + + // Unpack context + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + // Scratchpad memory + uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i; + + // Dimensions + uint32_t ne00 = src0->ne[0]; + uint32_t ne01 = src0->ne[1]; + uint32_t ne02 = src0->ne[2]; + uint32_t ne03 = src0->ne[3]; + + uint32_t nb01 = src0->nb[1]; + //uint32_t nb02 = src0->nb[2]; + //uint32_t nb03 = src0->nb[3]; + + uint32_t nb1 = dst->nb[1]; + //uint32_t nb2 = dst->nb[2]; + //uint32_t nb3 = dst->nb[3]; + + // Sort order + enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0]; + + // Rows to process + uint32_t total_rows = ne01 * ne02 * ne03; + uint32_t rows_per_thread = actx->nrows_per_thread; + uint32_t start_row = rows_per_thread * i; + uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); + + // Scratchpad layout: + // We need space for one row of float data (values) and one row of int32 indices. + // values: ne00 * sizeof(float) + // indices: ne00 * sizeof(int32_t) + // Padded to 128 bytes. 
+ + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + float * values_buf = (float *) spad; + int32_t * indices_buf = (int32_t *) (spad + values_size); + + for (uint32_t r = start_row; r < end_row; r++) { + uint32_t src_offset = r * nb01; + uint32_t dst_offset = r * nb1; + + uint8_t * src_ptr = (uint8_t *) src0->data + src_offset; + uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset; + + hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); + hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); + + // Initialize indices + for (uint32_t j = 0; j < ne00; j++) { + indices_buf[j] = j; + } + + // Sort values and mirror swaps to indices + if (order == GGML_SORT_ORDER_ASC) { + quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1); + } else { + quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1); + } + + // Copy indices back to DDR + hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00); + } +} + +int op_argsort(struct htp_ops_context * octx) { + // Check supported types + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // Allocate scratchpad + // We need 1 row of float + 1 row of int32 per thread. + uint32_t ne00 = octx->src0.ne[0]; + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); + size_t spad_per_thread = values_size + indices_size; + + // Make sure we round up to 256 for alignment requirements + spad_per_thread = hex_round_up(spad_per_thread, 256); + + size_t total_spad_size = spad_per_thread * octx->n_threads; + + if (octx->ctx->vtcm_size < total_spad_size) { + FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.size = total_spad_size; + octx->src0_spad.size_per_thread = spad_per_thread; + + FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", + octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], + octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], + octx->src0.data, octx->dst.data); + + uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + uint32_t n_jobs = MIN(total_rows, octx->n_threads); + + struct htp_argsort_context actx; + actx.octx = octx; + actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; + + // Run jobs + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index f49e8ee4..6af61782 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -64,6 +64,7 @@ enum htp_op { HTP_OP_SCALE = 16, HTP_OP_GET_ROWS = 17, HTP_OP_CPY = 18, + HTP_OP_ARGSORT = 19, INVALID }; @@ -103,22 +104,6 @@ static inline size_t htp_type_nbytes(uint32_t t) { return 0; } -static const char * htp_type_name(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return "fp32"; - case HTP_TYPE_F16: - return "fp16"; - case HTP_TYPE_Q4_0: - return "q4_0"; - case HTP_TYPE_Q8_0: - return "q8_0"; - case HTP_TYPE_MXFP4: - return "mxfp4"; - } - return 0; -} - // Internal types #define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 
602a2775..e1725151 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -98,5 +98,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index ffa6e18e..12a1b7f1 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) { return x; } +static inline int32_t hvx_vec_get_i32(HVX_Vector v) { + int32_t __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { // abs by clearing the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index 6b617b76..ae0dbed0 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ \ - const HVX_Vector zero = Q6_V_vsplat_R(0); \ - \ const uint32_t elem_size = sizeof(__fp16); \ const uint32_t epv = 128 / elem_size; \ const uint32_t nvec = n / epv; \ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e28a67a9..0ac7c557 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_argsort(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -1035,6 +1074,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_cpy_req(ctx, &req, bufs); break; + case HTP_OP_ARGSORT: + if (n_bufs != 2) { + FARF(ERROR, "Bad argsort-req buffer list"); + continue; + } + proc_argsort_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; 
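For reference, the row-level semantics that op_argsort implements above — fill an index row with 0..ne00-1, sort a scratch copy of the values while mirroring every swap into the indices, and write only the int32 index row back to DDR — can be sketched in plain scalar C. This is illustrative only; argsort_row_ref is a hypothetical helper, not part of the patch, and it uses an insertion sort where the HTP kernel uses the HVX-assisted quicksort shown in argsort-ops.c.

    #include <stdint.h>

    /* Reference sketch: produce the indices that would sort one row of n floats.
     * Values are copied into a scratch buffer and sorted while every swap (here,
     * every shift) is mirrored into the index array, matching the mirror-swap
     * approach used by argsort-ops.c. */
    static void argsort_row_ref(const float * row, int32_t * idx, float * scratch,
                                int n, int ascending) {
        for (int i = 0; i < n; i++) {
            scratch[i] = row[i];
            idx[i]     = i;
        }
        for (int i = 1; i < n; i++) {
            float   v  = scratch[i];
            int32_t vi = idx[i];
            int     j  = i - 1;
            while (j >= 0 && (ascending ? scratch[j] > v : scratch[j] < v)) {
                scratch[j + 1] = scratch[j];
                idx[j + 1]     = idx[j];
                j--;
            }
            scratch[j + 1] = v;
            idx[j + 1]     = vi;
        }
    }

As in the patch, only the index row is copied back out; the sorted values remain in per-thread scratch (VTCM on the HTP side) and are discarded.
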
From bdcb213c87128f29e14a97eb1936572cd754a221 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 4 Feb 2026 18:22:08 -0800 Subject: [PATCH 2/3] hexagon: argsort reject tensors with huge rows for now --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 2f92c761..2583b338 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2123,6 +2123,11 @@ static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * s return false; } + if (src0->ne[0] > (16*1024)) { + // reject tensors with huge rows for now + return false; + } + return true; } From 1a3603757ef40b31fdc4d716adaf35bc500732ef Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 05:03:12 +0000 Subject: [PATCH 3/3] Optimize Hexagon binary ops with DMA and context struct - Introduced `struct htp_binary_context` to localize precomputed values. - Implemented double-buffered DMA pipeline in `binary_job_f32_per_thread`. - Enabled aligned HVX operations via VTCM alignment. - Added support for row and inner broadcasting. - Fixed spad addressing in `binary_add_id_job_f32_per_thread`. Co-authored-by: max-krasnyansky <1380796+max-krasnyansky@users.noreply.github.com> --- ggml/src/ggml-hexagon/htp/binary-ops.c | 319 +++++++++++++++++++------ 1 file changed, 249 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index de22afe4..a8261ac5 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -22,6 +22,22 @@ typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 }; static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa }; +struct htp_binary_context { + struct htp_ops_context * octx; + + uint32_t src0_nrows_per_thread; + + struct fastdiv_values src0_div1; + struct fastdiv_values src0_div2; + struct fastdiv_values src0_div3; + struct fastdiv_values src0_div21; + + struct fastdiv_values src1_div1; + struct fastdiv_values src1_div2; + struct fastdiv_values src1_div3; + struct fastdiv_values src1_div21; +}; + #define htp_binary_preamble \ const struct htp_tensor * src0 = &octx->src0; \ const struct htp_tensor * src1 = &octx->src1; \ @@ -60,19 +76,57 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f3 \ const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; -static void binary_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - enum htp_op op) { - htp_binary_preamble; +#define htp_binary_context_preamble \ + struct htp_ops_context * octx = bctx->octx; \ + const struct htp_tensor * src0 = &octx->src0; \ + const struct htp_tensor * src1 = &octx->src1; \ + const struct htp_tensor * src2 = &octx->src2; \ + struct htp_tensor * dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const 
uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t src0_nrows_per_thread = bctx->src0_nrows_per_thread; + +static void binary_job_f32_per_thread(struct htp_binary_context * bctx, + uint32_t nth, + uint32_t ith, + enum htp_op op, + dma_queue * dma_queue) { + htp_binary_context_preamble; const size_t src0_row_size = nb01; const size_t src1_row_size = nb11; const size_t dst_row_size = nb1; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -85,76 +139,169 @@ static void binary_job_f32_per_thread(struct htp_ops_context * octx, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || - (0 == hex_is_aligned((void *) dst->data, VLEN))) { - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? 
func_table_HVX_opt[op] : func_table_HVX[op]; + uint8_t * src0_spad_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * src1_spad_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + uint8_t * dst_spad_base = octx->dst_spad.data + ith * octx->dst_spad.size_per_thread; - uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t src1_spad_half = octx->src1_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; - const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size); - uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size); + int BLOCK = src0_spad_half / src0_row_size_aligned; + if (BLOCK == 0) return; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + const uint8_t * data_src0 = (const uint8_t *) src0->data; + const uint8_t * data_src1 = (const uint8_t *) src1->data; + uint8_t * data_dst = (uint8_t *) dst->data; + + hvx_elemwise_f32_func func_HVX = func_table_HVX_opt[op]; + + uint32_t next_ir = src0_start_row; + int spad_idx = 0; const uint32_t ne02_ne01 = ne02 * ne01; - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); + // Prime 2 blocks + for (int k = 0; k < 2 && next_ir < src0_end_row; k++) { + uint32_t block_size = MIN(BLOCK, src0_end_row - next_ir); + + uint32_t i03 = fastdiv(next_ir, &bctx->src0_div21); + uint32_t i02 = fastdiv(next_ir - i03 * ne02_ne01, &bctx->src0_div1); + uint32_t i01 = (next_ir - i03 * ne02_ne01 - i02 * ne01); + block_size = MIN(block_size, ne01 - i01); + + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_div3); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_div2); + uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_div1); + + const uint8_t * s1_addr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size; + size_t s1_stride = nb11; + if (ne11 == 1 && ne01 > 1) { + s1_stride = 0; + } else if (ne11 != ne01) { + block_size = 1; + s1_stride = 0; + } - const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3); - const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2); - const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1); + const uint8_t * s0_addr = data_src0 + next_ir * src0_row_size; - const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size; + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_base + spad_idx * dst_spad_half), + dst_row_size, dst_row_size_aligned, 0); - if (ir + 1 < src0_end_row) { - hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); - if (src1_row_size == src0_row_size) { - hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1); - } + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_base + spad_idx * src0_spad_half, s0_addr), + src0_row_size_aligned, src0_row_size, block_size); + + size_t s1_width = ne10 * sizeof(float); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad_base + spad_idx * src1_spad_half, s1_addr), + src1_row_size_aligned, s1_stride, s1_width, block_size); + + next_ir += block_size; + spad_idx = (spad_idx + 1) % 2; + } + + spad_idx = 0; + for (uint32_t proc_ir = src0_start_row; proc_ir < src0_end_row; ) { + uint32_t block_size = MIN(BLOCK, src0_end_row - proc_ir); + + uint32_t 
i03 = fastdiv(proc_ir, &bctx->src0_div21); + uint32_t i02 = fastdiv(proc_ir - i03 * ne02_ne01, &bctx->src0_div1); + uint32_t i01 = (proc_ir - i03 * ne02_ne01 - i02 * ne01); + block_size = MIN(block_size, ne01 - i01); + + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_div3); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_div2); + uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_div1); + if (ne11 != ne01 && !(ne11 == 1 && ne01 > 1)) block_size = 1; + + dma_ptr dp_dst = dma_queue_pop(dma_queue); + dma_ptr dp_s0 = dma_queue_pop(dma_queue); + dma_ptr dp_s1 = dma_queue_pop(dma_queue); + + float * s0_buf = (float*)dp_s0.dst; + float * s1_buf = (float*)dp_s1.dst; + float * dst_buf = (float*)dp_dst.src; + + uint32_t nr0 = ne00 / ne10; + for (uint32_t b = 0; b < block_size; b++) { + float * p_s0 = s0_buf + b * (src0_row_size_aligned/4); + float * p_s1 = s1_buf + b * (src1_row_size_aligned/4); + float * p_d = dst_buf + b * (dst_row_size_aligned/4); + + if (nr0 > 1) { + float val = *p_s1; + hvx_splat_f32_a((uint8_t*)p_s1, val, nr0); + } + func_HVX((uint8_t*)p_d, (const uint8_t*)p_s0, (const uint8_t*)p_s1, ne00); } - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - if ((1 == is_aligned) && (nr0 == ne00)) { - hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0); - } else { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11); - } + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst + proc_ir * dst_row_size, dst_buf), + dst_row_size, dst_row_size_aligned, block_size); + + if (next_ir < src0_end_row) { + uint32_t pb_size = MIN(BLOCK, src0_end_row - next_ir); + + uint32_t ni03 = fastdiv(next_ir, &bctx->src0_div21); + uint32_t ni02 = fastdiv(next_ir - ni03 * ne02_ne01, &bctx->src0_div1); + uint32_t ni01 = (next_ir - ni03 * ne02_ne01 - ni02 * ne01); + pb_size = MIN(pb_size, ne01 - ni01); + + uint32_t ni13 = fastmodulo(ni03, ne13, &bctx->src1_div3); + uint32_t ni12 = fastmodulo(ni02, ne12, &bctx->src1_div2); + uint32_t ni11 = fastmodulo(ni01, ne11, &bctx->src1_div1); + + const uint8_t * ns1_addr = data_src1 + ni13 * nb13 + ni12 * nb12 + ni11 * src1_row_size; + size_t ns1_stride = nb11; + if (ne11 == 1 && ne01 > 1) { + ns1_stride = 0; + } else if (ne11 != ne01) { + pb_size = 1; + ns1_stride = 0; } - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00); - } else { - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); + + const uint8_t * ns0_addr = data_src0 + next_ir * src0_row_size; + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_base + spad_idx * dst_spad_half), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_base + spad_idx * src0_spad_half, ns0_addr), + src0_row_size_aligned, src0_row_size, pb_size); + + size_t ns1_width = ne10 * sizeof(float); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad_base + spad_idx * src1_spad_half, ns1_addr), + src1_row_size_aligned, ns1_stride, ns1_width, pb_size); + + next_ir += pb_size; + spad_idx = (spad_idx + 1) % 2; } - src0_ptr += src0_row_size; - dst_ptr += dst_row_size; + proc_ir += block_size; } + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, + FARF(HIGH, "binary-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, ne03, src0_start_row, 
src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, +static void binary_add_id_job_f32_per_thread(struct htp_binary_context * bctx, uint32_t nth, uint32_t ith, hvx_elemwise_f32_func func_HVX) { - htp_binary_preamble; + htp_binary_context_preamble; const size_t src0_row_size = nb01; const size_t src1_row_size = nb11; @@ -170,6 +317,8 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, return; } + uint8_t * spad_data = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -180,8 +329,8 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, const uint32_t ne02_ne01 = ne02 * ne01; for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { // src0 indices - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); + const uint32_t i03 = fastdiv(ir, &bctx->src0_div21); + const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &bctx->src0_div1); const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); // src1 indices @@ -219,17 +368,18 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, } static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; switch (octx->op) { case HTP_OP_MUL: case HTP_OP_ADD: case HTP_OP_SUB: - binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op); + binary_job_f32_per_thread(bctx, n, i, octx->op, octx->ctx->dma[i]); break; case HTP_OP_ADD_ID: - binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32); + binary_add_id_job_f32_per_thread(bctx, n, i, hvx_add_f32); break; default: @@ -281,10 +431,37 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { const size_t src1_row_size = src1->nb[1]; const size_t dst_row_size = dst->nb[1]; - // VTCM scratchpads for all tensors - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // If inner broadcast, we might need larger src1 spad to splat + size_t src1_spad_req = src1_row_size_aligned; + if (src0->ne[0] / src1->ne[0] > 1) { + src1_spad_req = MAX(src1_spad_req, src0_row_size_aligned); + } + + // Determine row allocation in VTCM + size_t spad_per_row_aligned = src0_row_size_aligned + src1_spad_req + dst_row_size_aligned; + size_t total_rows_vtcm = octx->ctx->vtcm_size / spad_per_row_aligned; + size_t rows_per_thread = total_rows_vtcm / n_threads; + + // We want at least 2 rows per thread for double buffering + if (rows_per_thread < 2) { + // Fallback: use minimum but check total size later + rows_per_thread = 2; + } + // Make even for simpler ping-pong + rows_per_thread &= ~1; + if (rows_per_thread < 2) rows_per_thread = 2; + + octx->src0_spad.size_per_thread = rows_per_thread * src0_row_size_aligned; + octx->src1_spad.size_per_thread = 
rows_per_thread * src1_spad_req; + octx->dst_spad.size_per_thread = rows_per_thread * dst_row_size_aligned; + + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; @@ -308,19 +485,21 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + struct htp_binary_context bctx; + bctx.octx = octx; + bctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]); - octx->src0_div3 = init_fastdiv_values(src0->ne[3]); - octx->src0_div2 = init_fastdiv_values(src0->ne[2]); - octx->src0_div1 = init_fastdiv_values(src0->ne[1]); + bctx.src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]); + bctx.src0_div3 = init_fastdiv_values(src0->ne[3]); + bctx.src0_div2 = init_fastdiv_values(src0->ne[2]); + bctx.src0_div1 = init_fastdiv_values(src0->ne[1]); - octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]); - octx->src1_div3 = init_fastdiv_values(src1->ne[3]); - octx->src1_div2 = init_fastdiv_values(src1->ne[2]); - octx->src1_div1 = init_fastdiv_values(src1->ne[1]); + bctx.src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]); + bctx.src1_div3 = init_fastdiv_values(src1->ne[3]); + bctx.src1_div2 = init_fastdiv_values(src1->ne[2]); + bctx.src1_div1 = init_fastdiv_values(src1->ne[1]); - worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, &bctx, n_jobs); } return err;
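
The double-buffered DMA pipeline introduced in PATCH 3/3 is a standard ping-pong scheme: each per-thread VTCM scratchpad is split in half, and while the HVX code computes on one half the DMA engine fills the other half with the next block of rows. A minimal sketch of that pattern follows, with hypothetical helper names (dma_read_async, dma_write_async, dma_wait, compute_rows) standing in for the dma_queue_push_ddr_to_vtcm / dma_queue_push_vtcm_to_ddr / dma_queue_pop calls used by the patch; it is a sketch of the control flow, not the patch's implementation.

    #include <stddef.h>
    #include <stdint.h>

    /* Hypothetical DMA helpers, declared only so the sketch is self-contained. */
    void dma_read_async (float * vtcm_dst, const float * ddr_src, uint32_t rows, uint32_t row_elems);
    void dma_write_async(float * ddr_dst, const float * vtcm_src, uint32_t rows, uint32_t row_elems);
    void dma_wait(const void * vtcm_buf);
    void compute_rows(float * vtcm_buf, uint32_t rows, uint32_t row_elems);

    /* Ping-pong pipeline: vtcm_half[0] and vtcm_half[1] alternate between
     * "being filled by DMA" and "being computed on", so transfers overlap
     * with aligned HVX work. */
    void pipeline_rows(const float * ddr_src, float * ddr_dst, float * vtcm_half[2],
                       uint32_t n_blocks, uint32_t rows_per_block, uint32_t row_elems) {
        int cur = 0;
        /* prime block 0 */
        dma_read_async(vtcm_half[cur], ddr_src, rows_per_block, row_elems);

        for (uint32_t b = 0; b < n_blocks; b++) {
            if (b + 1 < n_blocks) {
                /* start fetching block b+1 into the other half before computing */
                dma_read_async(vtcm_half[cur ^ 1],
                               ddr_src + (size_t) (b + 1) * rows_per_block * row_elems,
                               rows_per_block, row_elems);
            }
            dma_wait(vtcm_half[cur]);                                 /* block b is resident */
            compute_rows(vtcm_half[cur], rows_per_block, row_elems);  /* aligned HVX work    */
            dma_write_async(ddr_dst + (size_t) b * rows_per_block * row_elems,
                            vtcm_half[cur], rows_per_block, row_elems);
            cur ^= 1;                                                 /* swap halves */
        }
    }

In binary_job_f32_per_thread the "prime 2 blocks" loop plays the role of the initial dma_read_async calls, spad_idx is the ping-pong toggle, and the final dma_queue_flush() drains the last write-back to DDR.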