Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions ggml/src/ggml-hexagon/ggml-hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2111,6 +2111,21 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
return true;
}

// Reports whether the Hexagon backend can offload this ARGSORT op:
// the kernel only handles F32 input values and emits I32 indices.
static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * values  = op->src[0]; // input row data
    const struct ggml_tensor * indices = op;         // output permutation

    return values->type == GGML_TYPE_F32 && indices->type == GGML_TYPE_I32;
}

static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const int32_t * op_params = &op->op_params[0];

Expand Down Expand Up @@ -2316,6 +2331,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
return n_bufs;
}

// Fills an HTP ARGSORT request: copies op params and registers the two DMA
// buffers (src0 is CPU-written/DSP-read, dst is DSP-written/CPU-read).
// Returns the number of dspqueue buffers initialized.
static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    req->op = HTP_OP_ARGSORT;
    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));

    size_t idx = htp_req_buff_init(&req->src0, &bufs[0], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    idx += htp_req_buff_init(&req->dst, &bufs[idx], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

    return idx;
}

template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
switch (t->op) {
Expand Down Expand Up @@ -2564,6 +2590,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
break;

case GGML_OP_ARGSORT:
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
break;

default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
Expand Down Expand Up @@ -2968,6 +2998,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_cpy(sess, op);
break;

case GGML_OP_ARGSORT:
supp = ggml_hexagon_supported_argsort(sess, op);
break;

default:
break;
}
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-hexagon/htp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include_directories(
${HEXAGON_SDK_ROOT}/incs
${HEXAGON_SDK_ROOT}/incs/stddef
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
${CMAKE_CURRENT_SOURCE_DIR}/../..
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}
Expand All @@ -28,6 +29,7 @@ add_library(${HTP_LIB} SHARED
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
argsort-ops.c
)

target_compile_definitions(${HTP_LIB} PRIVATE
Expand Down
313 changes: 313 additions & 0 deletions ggml/src/ggml-hexagon/htp/argsort-ops.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <HAP_farf.h>
#include <HAP_perf.h>

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "ggml.h"

#include "hvx-utils.h"
#include "hex-dma.h"

#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

// Shared, read-only state handed to every argsort worker thread.
struct htp_argsort_context {
    struct htp_ops_context * octx;             // op context: tensors, scratchpad, thread count
    uint32_t nrows_per_thread;                 // ceil(total_rows / n_jobs) — rows per worker slice
    struct fastdiv_values div_ne01;            // fastdiv reciprocal for ne01 (recover i02 from remainder)
    struct fastdiv_values div_ne02_ne01;       // fastdiv reciprocal for ne02*ne01 (recover i03 from flat row)
};

// Vectorized sort implementation since std::sort is not available.
// In-place ascending Hoare-partition quicksort over values[left..right] that
// mirrors every swap into indices[], keeping both arrays in lockstep so the
// caller obtains the argsort permutation. The partition scans use HVX vectors
// to skip 32-float blocks that are entirely on the correct side of the pivot.
// NOTE(review): recursion depth is O(n) in the worst case (pathological pivot
// sequences) — confirm the DSP stack budget for very long rows.
// NOTE(review): NaN compares false against everything, so rows containing
// NaNs may not end up totally ordered — verify against ggml CPU reference.
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
    if (left >= right) return;

    // Middle element (by position) as the pivot value.
    int pivot_idx = (left + right) / 2;
    float pivot = values[pivot_idx];
    int i = left;
    int j = right;

    // Broadcast pivot; 1/0 vectors are used to count comparison hits per block.
    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
    HVX_Vector one_vec = Q6_V_vsplat_R(1);
    HVX_Vector zero_vec = Q6_V_vzero();

    while (i <= j) {
        // Vectorized scan for i: advance while values[i] < pivot.
        while (i <= j) {
            // Check if we have at least one full vector (32 f32 lanes in a
            // 128-byte HVX register) entirely inside [i, j].
            if (i + 32 <= j) {
                HVX_Vector vals_vec = *(HVX_UVector *)(values + i); // unaligned load
                HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(pivot_vec, vals_vec);

                // If all elements are < pivot, we can skip this whole block.
                // To check "all", we count matches.
                HVX_Vector matches = Q6_V_vmux_QVV(pred, one_vec, zero_vec);
                HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
                if (hvx_vec_get_i32(sum) == 32) {
                    i += 32;
                    continue;
                }
            }

            // Scalar fallback / cleanup for partial blocks or mixed vectors.
            if (values[i] < pivot) {
                i++;
            } else {
                break;
            }
        }

        // Vectorized scan for j: retreat while values[j] > pivot.
        while (i <= j) {
            if (j - 32 >= i) {
                // Load 32 elements ending at j.
                // Since we want `values[j] > pivot`, let's load from j-31 to j.
                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
                HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(vals_vec, pivot_vec);

                HVX_Vector matches = Q6_V_vmux_QVV(pred, one_vec, zero_vec);
                HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
                if (hvx_vec_get_i32(sum) == 32) {
                    j -= 32;
                    continue;
                }
            }

            if (values[j] > pivot) {
                j--;
            } else {
                break;
            }
        }

        // Both scans stopped on misplaced elements: swap values and mirror
        // the swap into the index array, then step past the pair.
        if (i <= j) {
            float tmp_val = values[i];
            values[i] = values[j];
            values[j] = tmp_val;

            int32_t tmp_idx = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp_idx;
            i++;
            j--;
        }
    }

    // Recurse into the two partitions (i > j here, so ranges do not overlap).
    if (left < j) quicksort_values_indices_asc(values, indices, left, j);
    if (i < right) quicksort_values_indices_asc(values, indices, i, right);
}

// Descending counterpart of quicksort_values_indices_asc: identical Hoare
// partition with HVX block-skipping, but with the comparison directions
// flipped (left scan advances while values[i] > pivot, right scan retreats
// while values[j] < pivot). Swaps are mirrored into indices[].
// NOTE(review): same caveats as the ascending variant — O(n) worst-case
// recursion depth and unspecified ordering for NaN values.
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
    if (left >= right) return;

    // Middle element (by position) as the pivot value.
    int pivot_idx = (left + right) / 2;
    float pivot = values[pivot_idx];
    int i = left;
    int j = right;

    // Broadcast pivot; 1/0 vectors are used to count comparison hits per block.
    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
    HVX_Vector one_vec = Q6_V_vsplat_R(1);
    HVX_Vector zero_vec = Q6_V_vzero();

    while (i <= j) {
        // Vectorized scan for i (values[i] > pivot)
        while (i <= j) {
            // Skip a whole 32-lane block when every lane is > pivot.
            if (i + 32 <= j) {
                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
                HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(vals_vec, pivot_vec);

                HVX_Vector matches = Q6_V_vmux_QVV(pred, one_vec, zero_vec);
                HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
                if (hvx_vec_get_i32(sum) == 32) {
                    i += 32;
                    continue;
                }
            }

            // Scalar fallback / cleanup.
            if (values[i] > pivot) {
                i++;
            } else {
                break;
            }
        }

        // Vectorized scan for j (values[j] < pivot)
        while (i <= j) {
            if (j - 32 >= i) {
                // Load the 32 elements ending at j (j-31 .. j).
                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
                HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(pivot_vec, vals_vec);

                HVX_Vector matches = Q6_V_vmux_QVV(pred, one_vec, zero_vec);
                HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
                if (hvx_vec_get_i32(sum) == 32) {
                    j -= 32;
                    continue;
                }
            }

            if (values[j] < pivot) {
                j--;
            } else {
                break;
            }
        }

        // Swap the misplaced pair in both arrays and step inward.
        if (i <= j) {
            float tmp_val = values[i];
            values[i] = values[j];
            values[j] = tmp_val;

            int32_t tmp_idx = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp_idx;
            i++;
            j--;
        }
    }

    // Recurse into the two partitions (i > j here, so ranges do not overlap).
    if (left < j) quicksort_values_indices_desc(values, indices, left, j);
    if (i < right) quicksort_values_indices_desc(values, indices, i, right);
}

// Worker-pool entry point: argsorts this worker's slice of rows.
// `n` is the total job count (unused here — the slice is derived from
// nrows_per_thread), `i` is this worker's 0-based index, `data` points at the
// shared htp_argsort_context. Each row is staged into per-thread VTCM
// scratch, sorted there, and the resulting I32 index row is written back.
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
    struct htp_ops_context * octx = actx->octx;

    // Unpack context
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * dst = &octx->dst;

    // Scratchpad memory: each worker owns a disjoint size_per_thread slice.
    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;

    // Dimensions (ggml layout: ne = element counts, nb = byte strides)
    uint32_t ne00 = src0->ne[0];
    uint32_t ne01 = src0->ne[1];
    uint32_t ne02 = src0->ne[2];
    uint32_t ne03 = src0->ne[3];

    uint32_t nb01 = src0->nb[1];
    uint32_t nb02 = src0->nb[2];
    uint32_t nb03 = src0->nb[3];

    uint32_t nb1 = dst->nb[1];
    uint32_t nb2 = dst->nb[2];
    uint32_t nb3 = dst->nb[3];

    // Sort order (ASC/DESC) comes through op_params[0].
    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];

    // Rows to process: contiguous slice [start_row, end_row) of the
    // flattened (i03, i02, i01) row space.
    uint32_t total_rows = ne01 * ne02 * ne03;
    uint32_t rows_per_thread = actx->nrows_per_thread;
    uint32_t start_row = rows_per_thread * i;
    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);

    // Scratchpad layout:
    // We need space for one row of float data (values) and one row of int32 indices.
    // values:  ne00 * sizeof(float)
    // indices: ne00 * sizeof(int32_t)
    // Padded to 128 bytes.

    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
    float * values_buf = (float *) spad;
    int32_t * indices_buf = (int32_t *) (spad + values_size);

    for (uint32_t r = start_row; r < end_row; r++) {
        // Calculate indices for 3D iteration flattened using fastdiv
        // uint32_t i03 = r / (ne02 * ne01);
        // uint32_t rem = r % (ne02 * ne01);
        // uint32_t i02 = rem / ne01;
        // uint32_t i01 = rem % ne01;

        uint32_t i03 = fastdiv(r, &actx->div_ne02_ne01);
        // NOTE(review): fastmodulo presumably takes the plain divisor alongside
        // the precomputed reciprocal — confirm against its declaration.
        uint32_t rem = fastmodulo(r, ne02 * ne01, &actx->div_ne02_ne01);
        uint32_t i02 = fastdiv(rem, &actx->div_ne01);
        uint32_t i01 = rem - i02 * ne01;

        uint32_t src_offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint32_t dst_offset = i03 * nb3 + i02 * nb2 + i01 * nb1;

        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
        uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;

        // Prefetch and Copy row data to VTCM
        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);

        // Use vector copy if available/efficient, handles unaligned
        hvx_copy_f32_uu((uint8_t*)values_buf, src_ptr, ne00);

        // Initialize indices with the identity permutation.
        for (uint32_t j = 0; j < ne00; j++) {
            indices_buf[j] = j;
        }

        // Sort values and mirror swaps to indices
        if (order == GGML_SORT_ORDER_ASC) {
            quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
        } else {
            quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
        }

        // Copy indices back to DDR. hvx_copy_f32_uu is reused for the i32
        // index row since both element types are 4 bytes wide.
        hvx_copy_f32_uu(dst_ptr, (const uint8_t *) indices_buf, ne00);
    }
}

int op_argsort(struct htp_ops_context * octx) {
// Check supported types
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}

// Allocate scratchpad
// We need 1 row of float + 1 row of int32 per thread.
uint32_t ne00 = octx->src0.ne[0];
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
size_t spad_per_thread = values_size + indices_size;

// Make sure we round up to 256 for alignment requirements
spad_per_thread = hex_round_up(spad_per_thread, 256);

size_t total_spad_size = spad_per_thread * octx->n_threads;

if (octx->ctx->vtcm_size < total_spad_size) {
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
return HTP_STATUS_VTCM_TOO_SMALL;
}

octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src0_spad.size = total_spad_size;
octx->src0_spad.size_per_thread = spad_per_thread;

FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
octx->src0.data, octx->dst.data);

uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
uint32_t n_jobs = MIN(total_rows, octx->n_threads);

struct htp_argsort_context actx;
actx.octx = octx;
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
// Initialize fastdiv values
actx.div_ne01 = init_fastdiv_values(octx->src0.ne[1]);
actx.div_ne02_ne01 = init_fastdiv_values(octx->src0.ne[2] * octx->src0.ne[1]);

// Run jobs
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);

return HTP_STATUS_OK;
}
Loading
Loading