From 5aaf5de155d154bccb98f71536c5c29c2a2f4b9d Mon Sep 17 00:00:00 2001
From: Max Krasnyansky
Date: Tue, 3 Feb 2026 22:21:19 -0800
Subject: [PATCH 1/2] hexagon: add ARGSORT op

---
 ggml/src/ggml-hexagon/ggml-hexagon.cpp   |  34 ++++
 ggml/src/ggml-hexagon/htp/CMakeLists.txt |   2 +
 ggml/src/ggml-hexagon/htp/argsort-ops.c  | 211 +++++++++++++++++++++++
 ggml/src/ggml-hexagon/htp/htp-msg.h      |  17 +-
 ggml/src/ggml-hexagon/htp/htp-ops.h      |   1 +
 ggml/src/ggml-hexagon/htp/hvx-copy.h     |   2 -
 ggml/src/ggml-hexagon/htp/main.c         |  47 +++++
 7 files changed, 296 insertions(+), 18 deletions(-)
 create mode 100644 ggml/src/ggml-hexagon/htp/argsort-ops.c

diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index 4f0a1620..2f92c761 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -2111,6 +2111,21 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
     return true;
 }
 
+static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0]; // values
+    const struct ggml_tensor * dst  = op;         // indices
+
+    if (src0->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (dst->type != GGML_TYPE_I32) {
+        return false;
+    }
+
+    return true;
+}
+
 static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
     const int32_t * op_params = &op->op_params[0];
 
@@ -2316,6 +2331,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
     return n_bufs;
 }
 
+static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_ARGSORT;
+    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     switch (t->op) {
@@ -2564,6 +2590,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             ggml_hexagon_dispatch_op(sess, node, flags);
             break;
 
+        case GGML_OP_ARGSORT:
+            ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
+            break;
+
         default:
             GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
     }
@@ -2968,6 +2998,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_cpy(sess, op);
             break;
 
+        case GGML_OP_ARGSORT:
+            supp = ggml_hexagon_supported_argsort(sess, op);
+            break;
+
         default:
             break;
     }
diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
index e8ef2030..5922dcc8 100644
--- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt
+++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
 include_directories(
     ${HEXAGON_SDK_ROOT}/incs
     ${HEXAGON_SDK_ROOT}/incs/stddef
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../../include
     ${CMAKE_CURRENT_SOURCE_DIR}/../..
     ${CMAKE_CURRENT_SOURCE_DIR}/..
     ${CMAKE_CURRENT_SOURCE_DIR}
@@ -28,6 +29,7 @@ add_library(${HTP_LIB} SHARED
     set-rows-ops.c
     get-rows-ops.c
     cpy-ops.c
+    argsort-ops.c
     )
 
 target_compile_definitions(${HTP_LIB} PRIVATE
diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c
new file mode 100644
index 00000000..d2053173
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c
@@ -0,0 +1,211 @@
+#include <assert.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "ggml.h"
+
+#include "hvx-utils.h"
+#include "hex-dma.h"
+
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+struct htp_argsort_context {
+    struct htp_ops_context * octx;
+    uint32_t nrows_per_thread;
+    struct fastdiv_values div_ne01;
+    struct fastdiv_values div_ne02_ne01;
+};
+
+// Scalar sort implementation since std::sort is not available.
+// Sorts indices based on values.
+static void quicksort_indices_asc(int32_t * indices, const float * data, int left, int right) {
+    if (left >= right) return;
+
+    int pivot_idx = indices[(left + right) / 2];
+    float pivot = data[pivot_idx];
+    int i = left;
+    int j = right;
+
+    while (i <= j) {
+        while (data[indices[i]] < pivot) i++;
+        while (data[indices[j]] > pivot) j--;
+        if (i <= j) {
+            int32_t tmp = indices[i];
+            indices[i] = indices[j];
+            indices[j] = tmp;
+            i++;
+            j--;
+        }
+    }
+
+    if (left < j) quicksort_indices_asc(indices, data, left, j);
+    if (i < right) quicksort_indices_asc(indices, data, i, right);
+}
+
+static void quicksort_indices_desc(int32_t * indices, const float * data, int left, int right) {
+    if (left >= right) return;
+
+    int pivot_idx = indices[(left + right) / 2];
+    float pivot = data[pivot_idx];
+    int i = left;
+    int j = right;
+
+    while (i <= j) {
+        while (data[indices[i]] > pivot) i++;
+        while (data[indices[j]] < pivot) j--;
+        if (i <= j) {
+            int32_t tmp = indices[i];
+            indices[i] = indices[j];
+            indices[j] = tmp;
+            i++;
+            j--;
+        }
+    }
+
+    if (left < j) quicksort_indices_desc(indices, data, left, j);
+    if (i < right) quicksort_indices_desc(indices, data, i, right);
+}
+
+static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_argsort_context * actx = (struct htp_argsort_context *) data;
+    struct htp_ops_context * octx = actx->octx;
+
+    // Unpack context
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * dst  = &octx->dst;
+
+    // Scratchpad memory
+    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
+
+    // Dimensions
+    uint32_t ne00 = src0->ne[0];
+    uint32_t ne01 = src0->ne[1];
+    uint32_t ne02 = src0->ne[2];
+    uint32_t ne03 = src0->ne[3];
+
+    uint32_t nb01 = src0->nb[1];
+    uint32_t nb02 = src0->nb[2];
+    uint32_t nb03 = src0->nb[3];
+
+    uint32_t nb1 = dst->nb[1];
+    uint32_t nb2 = dst->nb[2];
+    uint32_t nb3 = dst->nb[3];
+
+    // Sort order
+    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
+
+    // Rows to process
+    uint32_t total_rows = ne01 * ne02 * ne03;
+    uint32_t rows_per_thread = actx->nrows_per_thread;
+    uint32_t start_row = rows_per_thread * i;
+    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
+
+    // Scratchpad layout:
+    // We need space for one row of float data (values) and one row of int32 indices.
+    //   values:  ne00 * sizeof(float)
+    //   indices: ne00 * sizeof(int32_t)
+    // Padded to 128 bytes.
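+    // Example (hypothetical shape): for ne00 = 1000, the values row takes
+    // round_up(4000, 128) = 4096 bytes and the index row another 4096 bytes,
+    // i.e. 8192 bytes of VTCM scratch per thread.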
+
+    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
+    float * values_buf = (float *) spad;
+    int32_t * indices_buf = (int32_t *) (spad + values_size);
+
+    for (uint32_t r = start_row; r < end_row; r++) {
+        // Calculate indices for 3D iteration flattened using fastdiv
+        //   uint32_t i03 = r / (ne02 * ne01);
+        //   uint32_t rem = r % (ne02 * ne01);
+        //   uint32_t i02 = rem / ne01;
+        //   uint32_t i01 = rem % ne01;
+
+        uint32_t i03 = fastdiv(r, &actx->div_ne02_ne01);
+        uint32_t rem = fastmodulo(r, ne02 * ne01, &actx->div_ne02_ne01);
+        uint32_t i02 = fastdiv(rem, &actx->div_ne01);
+        uint32_t i01 = rem - i02 * ne01;
+
+        uint32_t src_offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint32_t dst_offset = i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
+        uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
+
+        // Prefetch and copy the row data to VTCM
+        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
+
+        // The unaligned vector copy handles arbitrary src/dst alignment
+        hvx_copy_f32_uu((uint8_t *) values_buf, src_ptr, ne00);
+
+        // Initialize indices
+        for (uint32_t j = 0; j < ne00; j++) {
+            indices_buf[j] = j;
+        }
+
+        // Sort indices based on values
+        if (order == GGML_SORT_ORDER_ASC) {
+            quicksort_indices_asc(indices_buf, values_buf, 0, ne00 - 1);
+        } else {
+            quicksort_indices_desc(indices_buf, values_buf, 0, ne00 - 1);
+        }
+
+        // Copy indices back to DDR
+        // i32 and f32 have the same element size, so the f32 copy routine works here
+        hvx_copy_f32_uu(dst_ptr, (const uint8_t *) indices_buf, ne00);
+    }
+}
+
+int op_argsort(struct htp_ops_context * octx) {
+    // Check supported types
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    // Allocate scratchpad
+    // We need 1 row of float + 1 row of int32 per thread.
+    uint32_t ne00 = octx->src0.ne[0];
+    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
+    size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
+    size_t spad_per_thread = values_size + indices_size;
+
+    // Make sure we round up to 256 for alignment requirements
+    spad_per_thread = hex_round_up(spad_per_thread, 256);
+
+    size_t total_spad_size = spad_per_thread * octx->n_threads;
+
+    if (octx->ctx->vtcm_size < total_spad_size) {
+        FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src0_spad.size = total_spad_size;
+    octx->src0_spad.size_per_thread = spad_per_thread;
+
+    FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
+         octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
+         octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
+         octx->src0.data, octx->dst.data);
+
+    uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
+    uint32_t n_jobs = MIN(total_rows, octx->n_threads);
+
+    struct htp_argsort_context actx;
+    actx.octx = octx;
+    actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
+    // Initialize fastdiv values
+    actx.div_ne01 = init_fastdiv_values(octx->src0.ne[1]);
+    actx.div_ne02_ne01 = init_fastdiv_values(octx->src0.ne[2] * octx->src0.ne[1]);
+
+    // Run jobs
+    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
+
+    return HTP_STATUS_OK;
+}
diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h
index f49e8ee4..6af61782 100644
--- a/ggml/src/ggml-hexagon/htp/htp-msg.h
+++ b/ggml/src/ggml-hexagon/htp/htp-msg.h
@@ -64,6 +64,7 @@ enum htp_op {
     HTP_OP_SCALE = 16,
     HTP_OP_GET_ROWS = 17,
     HTP_OP_CPY = 18,
+    HTP_OP_ARGSORT = 19,
     INVALID
 };
 
@@ -103,22 +104,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
     return 0;
 }
 
-static const char * htp_type_name(uint32_t t) {
-    switch (t) {
-        case HTP_TYPE_F32:
-            return "fp32";
-        case HTP_TYPE_F16:
-            return "fp16";
-        case HTP_TYPE_Q4_0:
-            return "q4_0";
-        case HTP_TYPE_Q8_0:
-            return "q8_0";
-        case HTP_TYPE_MXFP4:
-            return "mxfp4";
-    }
-    return 0;
-}
-
 // Internal types
 #define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
 #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h
index 602a2775..e1725151 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ops.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -98,5 +98,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
 int op_set_rows(struct htp_ops_context * octx);
 int op_get_rows(struct htp_ops_context * octx);
 int op_cpy(struct htp_ops_context * octx);
+int op_argsort(struct htp_ops_context * octx);
 
 #endif /* HTP_OPS_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h
index 6b617b76..ae0dbed0 100644
--- a/ggml/src/ggml-hexagon/htp/hvx-copy.h
+++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
     dst_type * restrict vdst = (dst_type *) dst; \
     src_type * restrict vsrc = (src_type *) src; \
 \
-    const HVX_Vector zero = Q6_V_vsplat_R(0); \
-\
     const uint32_t elem_size = sizeof(__fp16); \
     const uint32_t epv = 128 / elem_size; \
     const uint32_t nvec = n / epv; \
diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c
index e28a67a9..0ac7c557 100644
--- a/ggml/src/ggml-hexagon/htp/main.c
+++ b/ggml/src/ggml-hexagon/htp/main.c
@@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx,
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so it also needs to be flushed
+    rsp_bufs[0].fd =
bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_argsort(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -1035,6 +1074,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_cpy_req(ctx, &req, bufs); break; + case HTP_OP_ARGSORT: + if (n_bufs != 2) { + FARF(ERROR, "Bad argsort-req buffer list"); + continue; + } + proc_argsort_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; From f70983d8c2749930bf4acd2381a832b6e396d99b Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 22:21:00 +0000 Subject: [PATCH 2/2] Optimize Hexagon argsort with vectorized partition Replaced the scalar Quicksort implementation with a vectorized version using HVX intrinsics. - Changed sorting strategy to direct sort on values buffer with mirrored index swaps for better vectorization. - Added `hvx_vec_get_i32` to `hvx-base.h`. - Implemented partition loop using vector comparisons and reduction-based "all check" (workaround for missing `Q6_Q_all_P`). Co-authored-by: max-krasnyansky <1380796+max-krasnyansky@users.noreply.github.com> --- ggml/src/ggml-hexagon/htp/argsort-ops.c | 148 ++++++++++++++++++++---- ggml/src/ggml-hexagon/htp/hvx-base.h | 6 + 2 files changed, 131 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index d2053173..66359c7e 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -26,54 +26,156 @@ struct htp_argsort_context { struct fastdiv_values div_ne02_ne01; }; -// Scalar sort implementation since std::sort is not available. -// Sorts indices based on values. -static void quicksort_indices_asc(int32_t * indices, const float * data, int left, int right) { +// Vectorized sort implementation since std::sort is not available. +// Sorts values and mirrors swaps to indices. 
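+// Each HVX vector holds 32 fp32 lanes (128 bytes). A 32-element block can be
+// skipped only if every lane passes the pivot comparison; since there is no
+// cheap "all lanes true" predicate check here, the mask is mux'ed to 0/1 per
+// lane and reduced, and a sum of 32 means the whole block qualifies.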
+static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) { if (left >= right) return; - int pivot_idx = indices[(left + right) / 2]; - float pivot = data[pivot_idx]; + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; int i = left; int j = right; + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + HVX_Vector one_vec = Q6_V_vsplat_R(1); + HVX_Vector zero_vec = Q6_V_vzero(); + while (i <= j) { - while (data[indices[i]] < pivot) i++; - while (data[indices[j]] > pivot) j--; + // Vectorized scan for i + while (i <= j) { + // Check if we have at least one full vector + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(pivot_vec, vals_vec); + + // If all elements are < pivot, we can skip this whole block + // To check "all", we count matches. + HVX_Vector matches = Q6_V_vmux_QVV(pred, one_vec, zero_vec); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + if (hvx_vec_get_i32(sum) == 32) { + i += 32; + continue; + } + } + + // Scalar fallback / cleanup + if (values[i] < pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j + while (i <= j) { + if (j - 32 >= i) { + // Load 32 elements ending at j. + // Since we want `values[j] > pivot`, let's load from j-31 to j. + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(vals_vec, pivot_vec); + + HVX_Vector matches = Q6_V_vmux_QVV(pred, one_vec, zero_vec); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + if (hvx_vec_get_i32(sum) == 32) { + j -= 32; + continue; + } + } + + if (values[j] > pivot) { + j--; + } else { + break; + } + } + if (i <= j) { - int32_t tmp = indices[i]; + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; indices[i] = indices[j]; - indices[j] = tmp; + indices[j] = tmp_idx; i++; j--; } } - if (left < j) quicksort_indices_asc(indices, data, left, j); - if (i < right) quicksort_indices_asc(indices, data, i, right); + if (left < j) quicksort_values_indices_asc(values, indices, left, j); + if (i < right) quicksort_values_indices_asc(values, indices, i, right); } -static void quicksort_indices_desc(int32_t * indices, const float * data, int left, int right) { +static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) { if (left >= right) return; - int pivot_idx = indices[(left + right) / 2]; - float pivot = data[pivot_idx]; + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; int i = left; int j = right; + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + HVX_Vector one_vec = Q6_V_vsplat_R(1); + HVX_Vector zero_vec = Q6_V_vzero(); + while (i <= j) { - while (data[indices[i]] > pivot) i++; - while (data[indices[j]] < pivot) j--; + // Vectorized scan for i (values[i] > pivot) + while (i <= j) { + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(vals_vec, pivot_vec); + + HVX_Vector matches = Q6_V_vmux_QVV(pred, one_vec, zero_vec); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + if (hvx_vec_get_i32(sum) == 32) { + i += 32; + continue; + } + } + + if (values[i] > pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j (values[j] < pivot) + while (i <= j) { + if (j - 32 >= i) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(pivot_vec, vals_vec); + + HVX_Vector matches = 
Q6_V_vmux_QVV(pred, one_vec, zero_vec); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + if (hvx_vec_get_i32(sum) == 32) { + j -= 32; + continue; + } + } + + if (values[j] < pivot) { + j--; + } else { + break; + } + } + if (i <= j) { - int32_t tmp = indices[i]; + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; indices[i] = indices[j]; - indices[j] = tmp; + indices[j] = tmp_idx; i++; j--; } } - if (left < j) quicksort_indices_desc(indices, data, left, j); - if (i < right) quicksort_indices_desc(indices, data, i, right); + if (left < j) quicksort_values_indices_desc(values, indices, left, j); + if (i < right) quicksort_values_indices_desc(values, indices, i, right); } static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { @@ -149,11 +251,11 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { indices_buf[j] = j; } - // Sort indices based on values + // Sort values and mirror swaps to indices if (order == GGML_SORT_ORDER_ASC) { - quicksort_indices_asc(indices_buf, values_buf, 0, ne00 - 1); + quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1); } else { - quicksort_indices_desc(indices_buf, values_buf, 0, ne00 - 1); + quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1); } // Copy indices back to DDR diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index ffa6e18e..12a1b7f1 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) { return x; } +static inline int32_t hvx_vec_get_i32(HVX_Vector v) { + int32_t __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { // abs by clearing the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
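
Note (illustrative, not part of the patches): on the host side a single-node
graph is enough to exercise the new ARGSORT path end to end; in-tree coverage
normally comes from test-backend-ops. A minimal sketch using the public ggml
API, assuming a Hexagon backend handle obtained elsewhere (the helper name and
tensor shape are hypothetical):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Sort each row of a 4x64 f32 tensor in descending order; the result is the
// per-row index permutation as an I32 tensor.
static void run_argsort(ggml_backend_t backend) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data lives in the backend buffer below
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
    struct ggml_tensor * dst = ggml_argsort(ctx, src, GGML_SORT_ORDER_DESC); // dst is GGML_TYPE_I32

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dst);

    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // ... fill src with ggml_backend_tensor_set(), then:
    ggml_backend_graph_compute(backend, gf);
    // ... read the indices back with ggml_backend_tensor_get()

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
}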