From 97e26fdc676817fb5e6706b2714dc9d2afdb5471 Mon Sep 17 00:00:00 2001 From: wisman <2659530589@qq.com> Date: Sat, 2 Aug 2025 03:35:11 +0000 Subject: [PATCH] Preliminary SSSE3 support --- smallthinker/powerinfer/libaz/az/core/fp16.h | 1 + .../powerinfer/libaz/az/cpu/quant_types.cpp | 22 +- .../powerinfer/libaz/az/cpu/vec_dot.cpp | 209 +++++++++++++++++- .../powerinfer-cpu/include/convert.hpp | 90 +++++++- .../include/powerinfer-cpu-data.hpp | 11 +- .../powerinfer/powerinfer-cpu/src/vec_dot.hpp | 195 +++++++++++++++- 6 files changed, 508 insertions(+), 20 deletions(-) diff --git a/smallthinker/powerinfer/libaz/az/core/fp16.h b/smallthinker/powerinfer/libaz/az/core/fp16.h index 073bdf70..6c43441f 100644 --- a/smallthinker/powerinfer/libaz/az/core/fp16.h +++ b/smallthinker/powerinfer/libaz/az/core/fp16.h @@ -6,6 +6,7 @@ #include "az/core/intrinsics.hpp" #include "stdint.h" +#include #if defined(__cplusplus) extern "C" { diff --git a/smallthinker/powerinfer/libaz/az/cpu/quant_types.cpp b/smallthinker/powerinfer/libaz/az/cpu/quant_types.cpp index 72488806..e77fb4cc 100644 --- a/smallthinker/powerinfer/libaz/az/cpu/quant_types.cpp +++ b/smallthinker/powerinfer/libaz/az/cpu/quant_types.cpp @@ -187,7 +187,27 @@ void quantize_row_q8_0(block_q8_0 *out, const float *in, size_t n) { #endif // __AVX2__ } #else - abort(); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (size_t j = 0; j < block_q8_0::block_size; j++) { + const float v = in[i*block_q8_0::block_size + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + } + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + out[i].d = AZ_FP32_TO_FP16(d); + + for (size_t j = 0; j < block_q8_0::block_size; ++j) { + const float x0 = in[i*block_q8_0::block_size + j]*id; + + out[i].qs[j] = roundf(x0); + } + } #endif } diff --git a/smallthinker/powerinfer/libaz/az/cpu/vec_dot.cpp b/smallthinker/powerinfer/libaz/az/cpu/vec_dot.cpp index 7a5685b6..8d9c6c56 100644 --- a/smallthinker/powerinfer/libaz/az/cpu/vec_dot.cpp +++ b/smallthinker/powerinfer/libaz/az/cpu/vec_dot.cpp @@ -9,12 +9,13 @@ static int sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); #endif -#if defined(__AVX2__) +// Ref:llama.cpp(https://github.com/ggml-org/llama.cpp) ggml/src/ggml-cpu/arch/x86/quants.c +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) namespace { - // multiply int8_t, add results pairwise twice -inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { // Get absolute values of x vectors const __m128i ax = _mm_sign_epi8(x, x); // Sign the values of the y vectors @@ -25,8 +26,9 @@ inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { return _mm_madd_epi16(ones, dot); } +#if __AVX__ || __AVX2__ || __AVX512F__ // horizontally add 8 floats -inline float hsum_float_8(const __m256 x) { +static inline float hsum_float_8(const __m256 x) { __m128 res = _mm256_extractf128_ps(x, 1); res = _mm_add_ps(res, _mm256_castps256_ps128(x)); res = _mm_add_ps(res, _mm_movehl_ps(res, res)); @@ -34,27 +36,63 @@ inline float hsum_float_8(const __m256 x) { return _mm_cvtss_f32(res); } +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, 
sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -inline __m256i bytes_from_nibbles_32(const uint8_t *rsi) { +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp); - const __m256i lowMask = _mm256_set1_epi8(0xF); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); return _mm256_and_si256(lowMask, bytes); } // add int16_t pairwise and return as float vector -inline __m256 sum_i16_pairs_float(const __m256i x) { +static inline __m256 sum_i16_pairs_float(const __m256i x) { const __m256i ones = _mm256_set1_epi16(1); const __m256i summed_pairs = _mm256_madd_epi16(ones, x); return _mm256_cvtepi32_ps(summed_pairs); } -inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); @@ -62,7 +100,8 @@ inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { #endif } -inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { #if __AVXVNNIINT8__ const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); @@ -76,9 +115,155 @@ inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { #endif } +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = 
_mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} + +static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { + const __m128i ax = _mm_sign_epi8(x, x); + const __m128i sy = _mm_sign_epi8(y, x); + return _mm_maddubs_epi16(ax, sy); +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 
15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); } -#endif // __AVX2__ +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors +static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1, + const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) { + const __m128i mone = _mm_set1_epi16(1); + + const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1); + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); + const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1); + const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1); + return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1)); +} + +// quad fp16 delta calculation +static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) { + // GGML_CPU_FP16_TO_FP32 is faster than Intel F16C + return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0))); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline 
float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ } +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) namespace az::cpu { diff --git a/smallthinker/powerinfer/powerinfer-cpu/include/convert.hpp b/smallthinker/powerinfer/powerinfer-cpu/include/convert.hpp index 29834b70..50b1ca3f 100644 --- a/smallthinker/powerinfer/powerinfer-cpu/include/convert.hpp +++ b/smallthinker/powerinfer/powerinfer-cpu/include/convert.hpp @@ -146,9 +146,97 @@ inline void quantize_row_q8_0(const float *__restrict__ x, void *__restrict__ vy _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); #endif // __AVX2__ } +#elif defined(__SSSE3__) // SSSE3 does not provide a significant speedup over scalar operations on an i9-14900 + for (int i = 0; i < nb; i++) { + __m128 v0 = _mm_loadu_ps(x + 0); + __m128 v1 = _mm_loadu_ps(x + 4); + __m128 v2 = _mm_loadu_ps(x + 8); + __m128 v3 = _mm_loadu_ps(x + 12); + __m128 v4 = _mm_loadu_ps(x + 16); + __m128 v5 = _mm_loadu_ps(x + 20); + __m128 v6 = _mm_loadu_ps(x + 24); + __m128 v7 = _mm_loadu_ps(x + 28); + x += 32; + + const __m128 sign_mask = _mm_set1_ps(-0.0f); + + + __m128 max_abs_0 = _mm_max_ps(_mm_andnot_ps(sign_mask, v0), _mm_andnot_ps(sign_mask, v1)); + __m128 max_abs_1 = _mm_max_ps(_mm_andnot_ps(sign_mask, v2), _mm_andnot_ps(sign_mask, v3)); + __m128 max_abs_2 = _mm_max_ps(_mm_andnot_ps(sign_mask, v4), _mm_andnot_ps(sign_mask, v5)); + __m128 max_abs_3 = _mm_max_ps(_mm_andnot_ps(sign_mask, v6), _mm_andnot_ps(sign_mask, v7)); + + max_abs_0 = _mm_max_ps(max_abs_0, max_abs_1); + max_abs_2 = _mm_max_ps(max_abs_2, max_abs_3); + + __m128 max_abs = _mm_max_ps(max_abs_0, max_abs_2); + + + max_abs = _mm_max_ps(max_abs, _mm_shuffle_ps(max_abs, max_abs, _MM_SHUFFLE(2, 3, 0, 1))); + max_abs = _mm_max_ps(max_abs, _mm_shuffle_ps(max_abs, max_abs, _MM_SHUFFLE(1, 0, 3, 2))); + const float max_scalar = _mm_cvtss_f32(max_abs); + + + const float d = max_scalar / 127.0f; + y[i].d = POWERINFER_FP32_TO_FP16(d); + const float id = (max_scalar != 0.0f) ? 127.0f / max_scalar : 0.0f; + const __m128 mul = _mm_set1_ps(id); + + + v0 = _mm_mul_ps(v0, mul); + v1 = _mm_mul_ps(v1, mul); + v2 = _mm_mul_ps(v2, mul); + v3 = _mm_mul_ps(v3, mul); + v4 = _mm_mul_ps(v4, mul); + v5 = _mm_mul_ps(v5, mul); + v6 = _mm_mul_ps(v6, mul); + v7 = _mm_mul_ps(v7, mul); + + + __m128i i0 = _mm_cvtps_epi32(v0); + __m128i i1 = _mm_cvtps_epi32(v1); + __m128i i2 = _mm_cvtps_epi32(v2); + __m128i i3 = _mm_cvtps_epi32(v3); + __m128i i4 = _mm_cvtps_epi32(v4); + __m128i i5 = _mm_cvtps_epi32(v5); + __m128i i6 = _mm_cvtps_epi32(v6); + __m128i i7 = _mm_cvtps_epi32(v7); + + + __m128i p0 = _mm_packs_epi32(i0, i1); + __m128i p1 = _mm_packs_epi32(i2, i3); + __m128i p2 = _mm_packs_epi32(i4, i5); + __m128i p3 = _mm_packs_epi32(i6, i7); + + + __m128i q0 = _mm_packs_epi16(p0, p1); + __m128i q1 = _mm_packs_epi16(p2, p3); + + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), q0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), q1); + } #else - POWERINFER_ABORT("unsupported platform"); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[i*QK8_0 + j]; + if (fabsf(v) > amax) amax = fabsf(v); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = POWERINFER_FP32_TO_FP16(d); + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[i*QK8_0 + j]*id; + + y[i].qs[j] = roundf(x0); + } + } #endif } diff --git a/smallthinker/powerinfer/powerinfer-cpu/include/powerinfer-cpu-data.hpp b/smallthinker/powerinfer/powerinfer-cpu/include/powerinfer-cpu-data.hpp index 6cfbf980..664290c9 100644 --- a/smallthinker/powerinfer/powerinfer-cpu/include/powerinfer-cpu-data.hpp +++ b/smallthinker/powerinfer/powerinfer-cpu/include/powerinfer-cpu-data.hpp @@ -52,13 +52,20 @@ extern "C" HOST_API float ggml_lookup_fp16_to_fp32(ggml_fp16_t f); } #else #include - + #include "az/core/fp16.h" #ifdef _MSC_VER inline float ggml_compute_fp16_to_fp32(const ggml_fp16_t x) { return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))); } inline ggml_fp16_t ggml_compute_fp32_to_fp16(const float x) { return _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0); } - #else // NOT _MSC_VER + #elif defined(__F16C__) // NOT _MSC_VER inline float ggml_compute_fp16_to_fp32(const ggml_fp16_t x) { return _cvtsh_ss(x); } inline ggml_fp16_t ggml_compute_fp32_to_fp16(const float x) { return _cvtss_sh(x, 0); } + #else + inline float ggml_compute_fp16_to_fp32(const ggml_fp16_t x) { + az_fp16_t tmp; + tmp.value=x; + return az_compute_fp16_to_fp32(tmp); + } + inline ggml_fp16_t ggml_compute_fp32_to_fp16(const float x) { return (ggml_fp16_t)az_compute_fp32_to_fp16(x).value; } #endif // _MSC_VER #endif diff --git a/smallthinker/powerinfer/powerinfer-cpu/src/vec_dot.hpp b/smallthinker/powerinfer/powerinfer-cpu/src/vec_dot.hpp index 46786866..62507cd9 100644 --- a/smallthinker/powerinfer/powerinfer-cpu/src/vec_dot.hpp +++ b/smallthinker/powerinfer/powerinfer-cpu/src/vec_dot.hpp @@ -14,8 +14,10 @@ static int sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); using vec_dot_func_t = void (*)(int n, float * s, const void * vx, const void * vy); -#if defined(__AVX2__) +// Ref:llama.cpp(https://github.com/ggml-org/llama.cpp) ggml/src/ggml-cpu/arch/x86/quants.c +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) // multiply int8_t, add results pairwise twice static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { // Get absolute values of x vectors @@ -28,6 +30,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { return _mm_madd_epi16(ones, dot); } +#if __AVX__ || __AVX2__ || __AVX512F__ // horizontally add 8 floats static inline float hsum_float_8(const __m256 x) { __m128 res = _mm256_extractf128_ps(x, 1); @@ -37,12 +40,43 @@ static inline float hsum_float_8(const __m256 x) { return _mm_cvtss_f32(res); } +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 
bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); const __m256i lowMask = _mm256_set1_epi8( 0xF ); return _mm256_and_si256(lowMask, bytes); } @@ -55,10 +89,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { } static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); @@ -66,6 +104,7 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) #endif } +// multiply int8_t, add results pairwise twice and return as float vector static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { #if __AVXVNNIINT8__ const __m256i zero = _mm256_setzero_si256(); @@ -80,7 +119,155 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { #endif } -#endif // __AVX2__ +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( 
low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} + +static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { + const __m128i ax = _mm_sign_epi8(x, x); + const __m128i sy = _mm_sign_epi8(y, x); + return _mm_maddubs_epi16(ax, sy); +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors +static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1, + const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, 
const __m128i y_2_1) { + const __m128i mone = _mm_set1_epi16(1); + + const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1); + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); + const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1); + const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1); + return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1)); +} + +// quad fp16 delta calculation +static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) { + // GGML_CPU_FP16_TO_FP32 is faster than Intel F16C + return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0))); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) + #if defined (__ARM_NEON)