diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..efda3f8 --- /dev/null +++ b/.clang-format @@ -0,0 +1,24 @@ +--- +# Defaults for all languages. +BasedOnStyle: Google + +# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained. +# Developers are responsible for adhering to the 120 character maximum. +ColumnLimit: 0 +SortIncludes: false +DerivePointerAlignment: false +# Avoid adding spaces between tokens in GSL_SUPPRESS arguments. +# E.g., don't change "GSL_SUPPRESS(r.11)" to "GSL_SUPPRESS(r .11)". +WhitespaceSensitiveMacros: ["GSL_SUPPRESS"] + +# if you want to customize when working locally see https://clang.llvm.org/docs/ClangFormatStyleOptions.html for options. +# See ReformatSource.ps1 for a script to update all source according to the current options in this file. +# e.g. customizations to use Allman bracing and more indenting. +# AccessModifierOffset: -2 +# BreakBeforeBraces: Allman +# CompactNamespaces: false +# IndentCaseLabels: true +# IndentWidth: 4 +# NamespaceIndentation: All + +... diff --git a/.github/workflows/linux_ci.yml b/.github/workflows/linux_ci.yml index 4bee28b..78c6b56 100644 --- a/.github/workflows/linux_ci.yml +++ b/.github/workflows/linux_ci.yml @@ -7,27 +7,27 @@ on: pull_request: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true - + jobs: Linux_arm64_gcc_release: - runs-on: ["self-hosted", "1ES.Pool=mlas-linux-ARM64-CPU"] + runs-on: ubuntu-24.04-arm steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_release_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_release_workflow Linux_x64_gcc_ubuntu24_release_no_ort: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_release_no_ort_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_release_no_ort_workflow Linux_x64_gcc_ubuntu24_release: runs-on: ubuntu-24.04 @@ -43,16 +43,16 @@ jobs: config-file: ./.github/codeql/codeql-config.yml languages: 'cpp' - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_release_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_release_workflow - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v3 with: category: "/language:cpp" output: sarif-results upload: failure-only - + - name: filter-sarif uses: advanced-security/filter-sarif@v1 with: @@ -68,92 +68,92 @@ jobs: uses: github/codeql-action/upload-sarif@v3 with: sarif_file: sarif-results/cpp.sarif - + Linux_x64_gcc_ubuntu22_release: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_release_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_release_workflow Linux_x64_gcc_ubuntu24_debug: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_debug_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_debug_workflow Linux_x64_clang_ubuntu24_debug: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_clang_debug_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_clang_debug_workflow 
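For reference, the WhitespaceSensitiveMacros entry in the .clang-format file above exists because clang-format would otherwise tokenize the rule tag passed to GSL_SUPPRESS. A minimal sketch of the kind of annotated declaration this protects is shown below; GSL_SUPPRESS and the "r.11" tag come from the comment in the file, while the <gsl/gsl> include, the AllocateScratch function, and the rule description are illustrative assumptions, not code from this patch.

    // Illustrative sketch only: shows why "r.11" must stay a single token.
    // Assumes the Microsoft GSL umbrella header is available (GSL is a
    // dependency in cmake/deps.txt); the function itself is hypothetical.
    #include <cstddef>
    #include <gsl/gsl>

    GSL_SUPPRESS(r.11)  // suppress the Core Guidelines "avoid raw new/delete" check here
    int* AllocateScratch(std::size_t count) {
        return new int[count];  // deliberate raw allocation; the caller owns the buffer
    }

    // If clang-format rewrote the argument as "r .11", the attribute the macro
    // expands to on MSVC ([[gsl::suppress(...)]]) would no longer name a valid rule.
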
Linux_x64_gcc_ubuntu24_debug_asan: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_debug_asan_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_debug_asan_workflow Linux_wasm_debug_asan: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - mkdir -p build - cd build - git clone https://github.com/emscripten-core/emsdk.git - cd emsdk - ./emsdk install latest - ./emsdk activate latest - source emsdk_env.sh - cd .. - CFLAGS="-O0 -g -fsanitize=address" CXXFLAGS="-O0 -g -fsanitize=address" emcmake cmake .. -DCMAKE_BUILD_TYPE=Debug -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON - make -j $(nproc) all - ctest --output-on-failure + set -e -x + rm -rf build + mkdir -p build + cd build + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + source emsdk_env.sh + cd .. + CFLAGS="-O0 -g -fsanitize=address" CXXFLAGS="-O0 -g -fsanitize=address" emcmake cmake .. -DCMAKE_BUILD_TYPE=Debug -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON + make -j $(nproc) all + ctest --output-on-failure Linux_wasm_release: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - mkdir -p build - cd build - git clone https://github.com/emscripten-core/emsdk.git - cd emsdk - ./emsdk install latest - ./emsdk activate latest - source emsdk_env.sh - cd .. - CFLAGS="-O2 -DNDEBUG -g" CXXFLAGS="-O2 -DNDEBUG -g" emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON - make -j $(nproc) all - ctest --output-on-failure + set -e -x + rm -rf build + mkdir -p build + cd build + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + source emsdk_env.sh + cd .. + CFLAGS="-O2 -DNDEBUG -g" CXXFLAGS="-O2 -DNDEBUG -g" emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON + make -j $(nproc) all + ctest --output-on-failure Linux_wasm_release_no_exception: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - mkdir -p build - cd build - git clone https://github.com/emscripten-core/emsdk.git - cd emsdk - ./emsdk install latest - ./emsdk activate latest - source emsdk_env.sh - cd .. - CFLAGS="-O2 -DNDEBUG -g" CXXFLAGS="-O2 -DNDEBUG -g" emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON -DMLAS_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=ON - make -j $(nproc) all - ctest --output-on-failure + set -e -x + rm -rf build + mkdir -p build + cd build + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + source emsdk_env.sh + cd .. + CFLAGS="-O2 -DNDEBUG -g" CXXFLAGS="-O2 -DNDEBUG -g" emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON -DMLAS_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=ON + make -j $(nproc) all + ctest --output-on-failure diff --git a/.github/workflows/macos_ci.yml b/.github/workflows/macos_ci.yml index de92f12..d7cbed2 100644 --- a/.github/workflows/macos_ci.yml +++ b/.github/workflows/macos_ci.yml @@ -7,9 +7,9 @@ on: pull_request: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true - + jobs: # The following one doesn't work on macos-12. 
It has some compiling errors related to std::date MacOS14_arm64_release: @@ -17,15 +17,15 @@ jobs: steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset macos_arm64_release_workflow - + set -e -x + rm -rf build + cmake --workflow --preset macos_arm64_release_workflow + MacOS14_universal2_release: runs-on: macos-14 steps: - uses: actions/checkout@v4 - - run: | - set -e -x - rm -rf build - cmake --workflow --preset macos_universal2_release_workflow \ No newline at end of file + - run: |- + set -e -x + rm -rf build + cmake --workflow --preset macos_universal2_release_workflow diff --git a/.github/workflows/reusable_windows_build.yml b/.github/workflows/reusable_windows_build.yml new file mode 100644 index 0000000..224c5c6 --- /dev/null +++ b/.github/workflows/reusable_windows_build.yml @@ -0,0 +1,83 @@ +name: Reusable Windows Build + +on: + workflow_call: + inputs: + job-name: # For display purposes in the reusable workflow logs + required: true + type: string + runner-os: # New input for the runner OS + required: true + type: string + cmake-workflow-preset: + required: true + type: string + enable-codeql: + required: false + type: boolean + default: false + codeql-config-file: + required: false + type: string + default: ./.github/codeql/codeql-config.yml + codeql-sarif-output-dir: + required: false + type: string + default: sarif-results + +permissions: + actions: read + contents: read + security-events: write # Needed if CodeQL analysis runs & uploads + +jobs: + build_and_optional_analyze: + name: ${{ inputs.job-name }} + runs-on: ${{ inputs.runner-os }} # Use the input for runs-on + permissions: # Permissions moved here as they are job-specific + actions: read + contents: read + security-events: write # Needed if CodeQL analysis runs & uploads + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Initialize CodeQL (if enabled) + if: ${{ inputs.enable-codeql }} + uses: github/codeql-action/init@v3 + with: + config-file: ${{ inputs.codeql-config-file }} + languages: 'cpp' + + - name: Run CMake Workflow + run: | + cmake --workflow --preset ${{ inputs.cmake-workflow-preset }} + shell: cmd + + - name: Perform CodeQL Analysis (if enabled) + if: ${{ inputs.enable-codeql }} + uses: github/codeql-action/analyze@v3 + with: + category: "/language:cpp" + output: ${{ inputs.codeql-sarif-output-dir }} + upload: failure-only + + - name: Filter SARIF (if CodeQL enabled) + if: ${{ inputs.enable-codeql }} + uses: advanced-security/filter-sarif@v1 + with: + patterns: | + +**/*.cc + +**/*.h + -tests/**/*.* + -build/**/*.* + input: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif + output: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif.filtered + + - name: Upload filtered SARIF (if CodeQL enabled) + if: ${{ inputs.enable-codeql }} + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif.filtered + category: cpp-${{ inputs.job-name }} \ No newline at end of file diff --git a/.github/workflows/win_ci.yml b/.github/workflows/win_ci.yml new file mode 100644 index 0000000..23639c0 --- /dev/null +++ b/.github/workflows/win_ci.yml @@ -0,0 +1,72 @@ +name: Windows_CI + +on: + push: + branches: + - main + - rel-* + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + # Win32 Jobs + Win32_debug_no_ort: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Win32_Debug_NoOrt_CodeQL + runner-os: windows-2022 + 
cmake-workflow-preset: windows_win32_debug_no_ort_workflow + enable-codeql: true + + Win32_release_no_ort: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Win32_Release_NoOrt + runner-os: windows-2022 + cmake-workflow-preset: windows_win32_release_no_ort_workflow + enable-codeql: false + + # WinX64 Jobs + WinX64_debug_no_ort: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Winx64_Debug_NoOrt_CodeQL + runner-os: windows-2022 + cmake-workflow-preset: windows_x64_debug_no_ort_workflow + enable-codeql: true + + WinX64_release_no_ort: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Winx64_Release_NoOrt + runner-os: windows-2022 + cmake-workflow-preset: windows_x64_release_no_ort_workflow + enable-codeql: false + + WinX64_release: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Winx64_Release + runner-os: windows-2022 + cmake-workflow-preset: windows_x64_release_workflow + enable-codeql: false + + # Windows ARM64 Jobs + WinARM64_debug_no_ort: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: WinARM64_Debug_NoOrt + runner-os: windows-11-arm # Use ARM64 runner + cmake-workflow-preset: windows_arm64_debug_no_ort_workflow + enable-codeql: false + + WinARM64_release: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: WinARM64_Release + runner-os: windows-11-arm # Use ARM64 runner + cmake-workflow-preset: windows_arm64_release_workflow + enable-codeql: false \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index a3c0bfc..9bb06b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,8 +13,7 @@ cmake_policy(SET CMP0091 NEW) cmake_policy(SET CMP0117 NEW) # Project -project(MLAS C CXX ASM) - +project(MLAS C CXX) include(CheckCXXCompilerFlag) include(CheckLanguage) @@ -28,7 +27,11 @@ if(NOT CMAKE_C_STANDARD) endif() if(NOT CMAKE_CXX_STANDARD) message("Setting C++ standard to 20") - set(CMAKE_CXX_STANDARD 20) + if(WIN32) + set(CMAKE_CXX_STANDARD 23) + else() + set(CMAKE_CXX_STANDARD 20) + endif() endif() set(CMAKE_POSITION_INDEPENDENT_CODE ON) @@ -43,9 +46,63 @@ set(ONNXRUNTIME_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/src) set(ONNXRUNTIME_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) option(MLAS_ENABLE_WEBASSEMBLY_THREADS "Enable this option to create WebAssembly byte codes with multi-threads support" OFF) option(MLAS_ENABLE_WEBASSEMBLY_BROWSER_TESTS "Build all executables as html files" OFF) -option(MLAS_NO_ONNXRUNTIME "Disable ORT related code" OFF) +option(MLAS_NO_ONNXRUNTIME "Disable ONNX Runtime related code as much as possible" OFF) option(MLAS_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING "Enable this option to turn on exception catching" OFF) +if (MSVC) + # Make sure Visual Studio sets __cplusplus macro correctly: https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus") + + if (CMAKE_VS_PLATFORM_NAME) + # Multi-platform generator + set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) + else() + set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + endif() + if (onnxruntime_target_platform STREQUAL "ARM64") + set(onnxruntime_target_platform "ARM64") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM64EC") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") + set(onnxruntime_target_platform "ARM") + enable_language(ASM_MARMASM) + elseif 
(onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") + set(onnxruntime_target_platform "x64") + enable_language(ASM_MASM) + elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") + set(onnxruntime_target_platform "x86") + enable_language(ASM_MASM) + message("Enabling SAFESEH for x86 build") + set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") + else() + message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + endif() +else() + set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + enable_language(ASM) +endif() + +function(onnxruntime_configure_target target_name) + target_compile_options(${target_name} PRIVATE + $<$,$>:/MP> + $<$,$>:/MP> + ) + if(WIN32) + target_compile_options(${target_name} PRIVATE "$<$:/Zi>") + endif() +endfunction() + +function(onnxruntime_add_executable target_name) + add_executable(${target_name} ${ARGN}) + onnxruntime_configure_target(${target_name}) +endfunction() + +function(onnxruntime_add_static_library target_name) + add_library(${target_name} STATIC ${ARGN}) + onnxruntime_configure_target(${target_name}) +endfunction() + if(MLAS_ENABLE_WEBASSEMBLY_BROWSER_TESTS) #The variable cannot be set from cmake command line because otherwise emscripten's toolchain file will override it. set(CMAKE_EXECUTABLE_SUFFIX ".html") diff --git a/cmake/deps.txt b/cmake/deps.txt index 1524485..eba65ee 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -1,5 +1,5 @@ -eigen;https://gitlab.com/libeigen/eigen/-/archive/ff174f79264d3f8dc0115dea7a288f98208b694f/eigen-ff174f79264d3f8dc0115dea7a288f98208b694f.zip;e06074b74725f2677369be2eb2e97e57e2dc4353 +eigen;https://github.com/eigen-mirror/eigen/archive/1d8b82b0740839c0de7f1242a3585e3390ff5f33/eigen-1d8b82b0740839c0de7f1242a3585e3390ff5f33.zip;05b19b49e6fbb91246be711d801160528c135e34 +kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 -googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349 +googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f638fa0e724760e2ba07ff8cfba32cd644e1ce28 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 -microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 diff --git a/cmake/external_deps.cmake b/cmake/external_deps.cmake index 3e6495f..ecee953 100644 --- a/cmake/external_deps.cmake +++ b/cmake/external_deps.cmake @@ -74,14 +74,22 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME AND BUILD_TESTING) set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + # Disable the KleidiAI tests + set(KLEIDIAI_BUILD_TESTS OFF) + + onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) + onnxruntime_fetchcontent_makeavailable(kleidiai) + + # gtest and gmock - FetchContent_Declare( + onnxruntime_fetchcontent_declare( googletest URL ${DEP_URL_googletest} URL_HASH SHA1=${DEP_SHA1_googletest} + EXCLUDE_FROM_ALL FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest ) - 
FetchContent_MakeAvailable(googletest) + onnxruntime_fetchcontent_makeavailable(googletest) #google benchmark doesn't work for Emscripten if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") message("CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") @@ -90,12 +98,13 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME AND BUILD_TESTING) # We will not need to install benchmark since we link it statically. set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable benchmark install to avoid overwriting vendor install.") - FetchContent_Declare( + onnxruntime_fetchcontent_declare( google_benchmark URL ${DEP_URL_google_benchmark} URL_HASH SHA1=${DEP_SHA1_google_benchmark} + EXCLUDE_FROM_ALL FIND_PACKAGE_ARGS NAMES benchmark ) onnxruntime_fetchcontent_makeavailable(google_benchmark) endif() -endif() \ No newline at end of file +endif() diff --git a/include/mlas.h b/include/mlas.h index 28ae64c..40361be 100644 --- a/include/mlas.h +++ b/include/mlas.h @@ -57,13 +57,19 @@ Module Name: #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) || defined(MLAS_TARGET_ARM) #define MLAS_TARGET_ARM_ANY #endif +#if defined(__s390x__) +#define MLAS_TARGET_S390X +#endif #if defined(__VSX__) #define MLAS_TARGET_POWER #endif #if defined(__wasm__) #define MLAS_TARGET_WASM -#if defined(__wasm_simd128__) +#if defined(__wasm_relaxed_simd__) +#define MLAS_TARGET_WASM_RELAXED_SIMD +#define MLAS_TARGET_WASM_SIMD +#elif defined(__wasm_simd128__) #define MLAS_TARGET_WASM_SIMD #else #define MLAS_TARGET_WASM_SCALAR @@ -77,7 +83,7 @@ Module Name: // Define the support levels for the target architecture. // -#if defined(MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_ZVECTOR) #define MLAS_SUPPORTS_GEMM_DOUBLE #endif @@ -628,6 +634,49 @@ MlasGemm( { MlasGemmBatch(Shape, &DataParams, 1, ThreadPool); } +/** + * @brief Parameters that define the shape of a dynamically quantized GEMM operation. + * + * The structure holds the dimensions of the matrices involved in the GEMM + * computation: + * C = A * B + */ +struct MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS { + size_t M = 0; /**< Row size of matrix A */ + size_t N = 0; /**< Column size of matrix B */ + size_t K = 0; /**< Column size of matrix A and Row size of matrix B */ +}; +/** + * @brief Parameters that define the data buffers and layout for a dynamic quant GEMM. + * + * This structure provides the memory pointers and strides for matrices + * involved in a dynamically quantized GEMM operation, along with the packed B format. 
+ */ +struct MLAS_GEMM_DYN_QUANT_DATA_PARAMS { + const float* A = nullptr; /**< Pointer to input matrix A in FP32 format**/ + size_t lda = 0; /**< Number of elements between adjecent rows in A*/ + const void* PackedB = 0; /**< Points to packed weight matrix B */ + float* C = nullptr; /**< Points to output Matric C */ + size_t ldc = 0; /**< Number of elements between adjecent rows in Matrix C*/ + void* Workspace = nullptr; /**< Workspace buffer for LHS Packing Allocation */ + size_t WorkspaceSize = 0; /**< Workspace buffer size */ +}; + +void + MLASCALL + MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool); + +inline void +MlasDynamicQGemm( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool) { + MlasDynamicQGemmBatch(Shape, DataParams, 1, ThreadPool); +} // // Symmetric QGEMM has limited buffer overrun. @@ -680,22 +729,23 @@ MlasSymmQgemmBatch( // size_t -MLASCALL -MlasGemmPackBSize( - size_t N, - size_t K - ); + MLASCALL + MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); void -MLASCALL -MlasGemmPackB( - CBLAS_TRANSPOSE TransB, - size_t N, - size_t K, - const float* B, - size_t ldb, - void* PackedB - ); + MLASCALL + MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); size_t MLASCALL @@ -747,6 +797,22 @@ MlasSymmQgemmPackB( void* PackedB ); +size_t + MLASCALL + MlasDynamicQgemmPackBSize( + size_t N, + size_t K); + +void + MLASCALL + MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB); + // // Convolution routines. // @@ -990,11 +1056,12 @@ MlasComputeErf( size_t N ); +template void MLASCALL MlasComputeExp( - const float* Input, - float* Output, + const T* Input, + T* Output, size_t N ); @@ -1006,73 +1073,61 @@ MlasComputeLogistic( size_t N ); +template void -MLASCALL -MlasComputeSoftmax( - const float* Input, - float* Output, - size_t N, - size_t D, - bool LogSoftmax, - bool SmoothSoftmax, - MLAS_THREADPOOL* ThreadPool - ); - -void -MLASCALL -MlasComputeTanh( - const float* Input, - float* Output, - size_t N - ); - -// -// Transpose routines. -// - + MLASCALL + MlasComputeSoftmax( + const T* Input, + T* Output, + size_t N, + size_t D, + bool LogSoftmax, + bool SmoothSoftmax, + float Sink, + MLAS_THREADPOOL* ThreadPool); + +template void MLASCALL -MlasTranspose( - const uint8_t* Input, - uint8_t* Output, - size_t M, - size_t N +MlasComputeSoftcap( + const T* Input, + T* Output, + size_t N, + T cap ); +template void MLASCALL -MlasTranspose( - const int8_t* Input, - int8_t* Output, - size_t M, +MlasEltwiseAdd( + const T* left, + const T* right, + T* output, size_t N ); +template void MLASCALL -MlasTranspose( - const uint16_t* Input, - uint16_t* Output, - size_t M, +MlasComputeTanh( + const T* Input, + T* Output, size_t N ); -void -MLASCALL -MlasTranspose( - const uint32_t* Input, - uint32_t* Output, - size_t M, - size_t N - ); +// +// Transpose routines. +// +template void MLASCALL MlasTranspose( - const float* Input, - float* Output, + const DataType* Input, + DataType* Output, size_t M, - size_t N + size_t N, + MLAS_THREADPOOL* ThreadPool ); // @@ -1231,6 +1286,20 @@ MlasQuantizeLinearS4( int8_t ZeroPoint ); +// +// Linear dequantization routines. 
+// + +template +void + MLASCALL + MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint); + /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias @@ -1427,6 +1496,16 @@ MlasConvertHalfToFloatBuffer( size_t Count ); +#define MLAS_MIN_TENSOR_SIZE_FOR_HALF_TO_FLOAT_CONVERSION_IN_PARALLEL 128000 + +void + MLASCALL + MlasConvertHalfToFloatBufferInParallel( + const MLAS_FP16* Source, + float* Destination, + size_t Count, + MLAS_THREADPOOL* ThreadPool); + void MLASCALL MlasConvertFloatToHalfBuffer( @@ -1435,7 +1514,141 @@ MLAS_FP16* Destination, size_t Count ); - /** +/** + * @brief rotary embedding for one hidden state vector + * + * @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported. + * @param input: input tensor, of shape [dim] + * @param sin: sin tensor, of shape [dim/2] + * @param cos: cos tensor, of shape [dim/2] + * @param dim: dimension of rotary embedding + * @param interleaved: whether the real part and imaginary parts are interleaved + * @param output: output tensor, of shape [dim] + */ +template +void +MLASCALL +MlasRotaryEmbedOneRow( + const T* input, + const T* sin_data, + const T* cos_data, + size_t dim, + bool interleaved, + T* output +); + +/** + * @brief Supply matrices data information to half precision gemm functions + */ +struct MLAS_HGEMM_DATA_PARAMS { + const MLAS_FP16* A; /**< Supplies the address of matrix A */ + size_t lda; /**< Supplies the first dimension of matrix A. */ + const MLAS_FP16* B; /**< Supplies the address of matrix B */ + size_t ldb; /**< Supplies the first dimension of matrix B. */ + MLAS_FP16* C; /**< Supplies the address of matrix C */ + size_t ldc; /**< Supplies the first dimension of matrix C. */ + uint16_t alpha; /**< Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding. */ + uint16_t beta; /**< Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding. */ +}; + +/** + * @brief Check whether current CPU supports half precision gemm. + */ +bool +MLASCALL +MlasHGemmSupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB + ); + +/** + * @brief Check whether mlas supports GQA kernels with the type and transpose settings. + */ +template +bool +MLASCALL +MlasGQASupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB + ); + +/** + * @brief Batched half precision matrix/matrix multiply operation (HGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number of rows of matrix B. + * @param Data A array of matrices data parameters + * @param BatchSize Supplies number of multiplications in this batch + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +void +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_HGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + +/** + * @brief half precision matrix/matrix multiply operation (HGEMM) + * C = alpha * op(A) * op(B) + beta * C + * + * @param TransA Supplies the transpose operation for matrix A. Currently only support CblasNoTrans. 
+ * @param TransB Supplies the transpose operation for matrix B. Currently only support CblasTrans. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number of rows of matrix B. + * @param A Supplies the address of matrix A + * @param lda Supplies the first dimension of matrix A. + * @param B Supplies the address of matrix B + * @param ldb Supplies the first dimension of matrix B. + * @param C Supplies the address of matrix C + * @param ldc Supplies the first dimension of matrix C. + * @param alpha Supplies the scalar alpha multiplier (see GEMM definition) + * @param beta Supplies the scalar beta multiplier (see GEMM definition) + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support + * should be used. + */ +inline +void +MlasGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_FP16* A, + size_t lda, + const MLAS_FP16* B, + size_t ldb, + MLAS_FP16* C, + size_t ldc, + uint16_t alpha, + uint16_t beta, + MLAS_THREADPOOL* ThreadPool +) { + MLAS_HGEMM_DATA_PARAMS Data; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.C = C; + Data.ldc = ldc; + Data.alpha = alpha; + Data.beta = beta; + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} + +/** * @brief Whether current CPU supports FP16 acceleration. */ bool MLASCALL @@ -1780,20 +1993,22 @@ MlasConvDepthwise( MLAS_HALF_GEMM_POSTPROCESSOR* PostProc ); - inline void MlasTranspose( const MLAS_FP16* Input, MLAS_FP16* Output, size_t M, - size_t N + size_t N, + MLAS_THREADPOOL* ThreadPool ) { MlasTranspose( reinterpret_cast(Input), reinterpret_cast(Output), - M, N); + M, + N, + ThreadPool); } @@ -1869,3 +2084,14 @@ MlasFlashAttention( MlasFlashAttentionThreadedArgs* args, MLAS_THREADPOOL* ThreadPool ); + +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +/** + * @brief Function to override the packing mechanism decision if kleidi ai is included + * @param enable enable kleidiai packing (allow or disallow depending on true/false) + * @return + */ +void + MLASCALL + MlasGemmBatchPackUseKleidi(bool enable); +#endif diff --git a/include/mlas_gemm_postprocessor.h b/include/mlas_gemm_postprocessor.h index 8c24705..7ea29eb 100644 --- a/include/mlas_gemm_postprocessor.h +++ b/include/mlas_gemm_postprocessor.h @@ -16,7 +16,6 @@ Module Name: --*/ #pragma once -#include template class MLAS_GEMM_POSTPROCESSOR diff --git a/include/mlas_q4.h b/include/mlas_q4.h index aec1407..b43f089 100644 --- a/include/mlas_q4.h +++ b/include/mlas_q4.h @@ -266,7 +266,7 @@ MlasBlockwiseQuantizedShape( /** * @brief Compute the sizes of the quantized data and quantization parameter buffers. * - * @param qbits The bit width of each quantized value. + * @tparam qbits The bit width of each quantized value. * @param block_size The number of quantized values in a block. * @param columnwise Whether a block contains values from a matrix column (true) or row (false). * @param rows Number of matrix rows. @@ -277,18 +277,16 @@ MlasBlockwiseQuantizedShape( * * If the qbits or block_size values are unsupported the output sizes will be zero. 
*/ +template void MLASCALL MlasBlockwiseQuantizedBufferSizes( - int qbits, int block_size, bool columnwise, int rows, int columns, size_t& q_data_size_in_bytes, size_t& q_scale_num_elements, - size_t* q_zero_point_size_in_bytes -); - + size_t* q_zero_point_size_in_bytes); /** * @brief Blockwise 4 bits quantization, resulting elements and quantization diff --git a/include/mlas_qnbit.h b/include/mlas_qnbit.h index 232bf22..2a1e1fc 100644 --- a/include/mlas_qnbit.h +++ b/include/mlas_qnbit.h @@ -27,51 +27,61 @@ Module Name: * @brief Define compute types of block quantization, in order of decreasing accuracy. */ typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32, /*!< input fp32, accumulator fp32 */ - CompFp16, /*!< input fp16, accumulator fp16 */ - CompBf16, /*!< input bf16, accumulator fp32 */ - CompInt8, /*!< input int8, accumulator int32 */ - - // special values that should be the first and last actual values - - CompMostAccurate = CompUndef, - CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; + SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ + HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ + BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ + SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ + HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ +} MLAS_QNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. + * + * @tparam T data type of input A */ -struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) +template +struct MLAS_QNBIT_GEMM_DATA_PARAMS { + const T* A = nullptr; ///< address of A (float32/16 matrix) size_t lda = 0; ///< leading dimension of A const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + + /// + /// Address of scale * accumulate(quant - zp), one per block, where `scale`, `quant`, `zp` are respectively + /// an individual block's scale, quantized values, and zero point for the input `B`. + /// When converting the activation input (A) to uint8, we first convert the values to int8 and then + /// add a "bias" of +128 to convert the range of values from [-128, +127] to [0, +255]. + /// This input helps to "de-bias" the output of the +128 bias added to the activation input. + /// This input is to be used only when A is quantized to uint8. 
+ /// + const T* BlkUnsignedQuantAZeroPointCorrection = nullptr; + + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix - MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; }; /** * @brief Batched GEMM: C = A * B + Bias - * A must be a float32 matrix + * A must be a float32/16 matrix * B must be a quantized and packed n-bit int matrix * - * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * Call MlasIsQNBitGemmAvailable() with the same parameters to determine whether this function may be called. * - * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether - * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with - * MlasSQNBitGemmPackQuantBData(). + * Call MlasQNBitGemmPackQuantBDataSize() with the same parameters to determine whether + * MLAS_QNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with + * MlasQNBitGemmPackQuantBData(). * - * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should + * Call MlasQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should * point to an intermediate workspace buffer. * + * @tparam T data type of input A * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -81,36 +91,37 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] Workspace Address of intermediate workspace buffer. - If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + If MlasQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr ); /** - * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * @brief Determines whether a float32/16 quantized n-bit int GEMM implementation is available on the current platform. 
* * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -123,39 +134,43 @@ MlasIsSQNBitGemmAvailable( * @param[in] BatchN number of batches * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** * @brief Gets the size in bytes of the packed quantized B data. - * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of - * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). - * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to - * MlasSQNBitGemmBatch(). + * If non-zero, the quantized B data must first be packed by calling MlasQNBitGemmPackQuantBData() with a buffer of + * this size, and then that packed quantized B data buffer must be passed to MlasQNBitGemmBatch(). + * If zero, MlasQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to + * MlasQNBitGemmBatch(). * * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -181,21 +196,39 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] QuantBData quantized B data * @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum * @param[in] QuantBScale quantized B scale - * @param[in] has_zp_input whether QuantBZeroPoint is provided + * @param[in] HasZeroPoint whether QuantBZeroPoint is provided * @param[in] QuantBZeroPoint quantized B zero point * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSum, const void* QuantBScale, - bool has_zp_input, + bool HasZeroPoint, const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool ); + +/** + * @brief Returns true if scales are packed when calling MlasQNBitGemmPackQuantBData the first time. 
+ * + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] HasZeroPoint whether QuantBZeroPoint is provided + */ +bool MLASCALL +MlasQNBitGemmScalesPacked( + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + bool HasZeroPoint +); diff --git a/src/lib/sqnbitgemm.h b/include/qnbitgemm.h similarity index 53% rename from src/lib/sqnbitgemm.h rename to include/qnbitgemm.h index 2da336c..4c13310 100644 --- a/src/lib/sqnbitgemm.h +++ b/include/qnbitgemm.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.h + qnbitgemm.h Abstract: @@ -46,24 +46,26 @@ MlasAlignAddress(void* addr, const size_t alignment) return addr; } +template struct PackedQuantBDataStruct { PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { - // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize - constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); - - // _mm256_load_si256 requires alignment on a 32-byte boundary - PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); - QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); +#if defined(MLAS_TARGET_AMD64_IX86) + // avx512 requires alignment on a 64-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64); +#else + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; +#endif + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); } std::byte* PackedQuantBData; - float* PackedQuantBScale; - float* QuantBBlkSum; + T* PackedQuantBScale; + T* QuantBBlkSum; void* QuantBWorkspace_; size_t N_, BlockCountK_, BlkLen_; @@ -84,49 +86,77 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) // Kernel dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH { +struct MLAS_QNBIT_GEMM_DISPATCH { // // Quantized B data packing function prototypes. // - /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ - typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; - /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). 
*/ - typedef void(SQ4BitGemmPackQuantBData_Fn)( + /** Gets size of packed quantized B data containing 8-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q8BitGemmPackQuantBDataSize_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + Q8BitGemmPackQuantBDataSize_Fn* Q8BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ + typedef void(Q4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool ); - SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ); SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; + typedef void(SQ8BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool + ); + + SQ8BitGemmPackQuantBDataAndSumBlk_Fn* SQ8BitGemmPackQuantBDataAndBlkSum = nullptr; + // // Workspace size calculation function prototypes. // @@ -139,17 +169,19 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)( + typedef size_t(QNBitGemmPerGemmWorkspaceSize_Fn)( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr; + QNBitGemmPerGemmWorkspaceSize_Fn* QNBitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -157,15 +189,15 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( + typedef size_t(QNBitGemmPerGemmWorkspaceAlignment_Fn)( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr; + QNBitGemmPerGemmWorkspaceAlignment_Fn* QNBitGemmPerGemmWorkspaceAlignment = nullptr; // - // CompFp32 kernel function prototypes. + // SQNBIT_CompFp32 kernel function prototypes. 
// /** @@ -228,12 +260,76 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { size_t BlockStrideQuantB ); - Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; + Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr; // - // CompInt8 kernel function prototypes. + // SQNBIT_CompInt8 kernel function prototypes. // + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * A should be packed using QuantizeA_Packed_CompInt8. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param PackedQuantBData Supplies the packed quantized B matrix data. + * @param[out] C Supplies the output C matrix. + * @param RangeStartM Start of M range. + * @param RangeCountM Number of rows of A and C. + * @param RangeStartN Start of N range. + * @param RangeCountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param ldc Number of elements between adjacent rows of C. + */ + typedef void(SQ4BitGemmKernel_Packed_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* PackedQuantBData, + float* C, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN, + size_t CountK, + size_t ldc, + const float* Bias + ); + + SQ4BitGemmKernel_Packed_CompInt8_Fn* SQ4BitGemmKernel_Packed_CompInt8 = nullptr; + /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. @@ -273,6 +369,45 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 8-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. 
+ Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + */ + typedef size_t(SQ8BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum + ); + + SQ8BitGemmKernel_BlkSum_CompInt8_Fn* SQ8BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. @@ -310,6 +445,38 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; + /** + * @brief Whether to use SQ4BitGemmKernel_Packed_CompInt8 for this problem. + */ + typedef bool(UsePacked_CompInt8_Fn)( + size_t K, + size_t BlkLen, + bool HasZp + ); + + UsePacked_CompInt8_Fn* UsePacked_CompInt8 = nullptr; + + /** + * @brief Block quantize values from matrix A from floats to quantized 8-bit integers. + * Used in conjunction with SQ4BitGemmKernel_Packed_CompInt8. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountM Number of rows of A. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. + */ + typedef void(QuantizeA_Packed_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountM, + size_t CountK, + std::byte* QuantA + ); + + QuantizeA_Packed_CompInt8_Fn* QuantizeA_Packed_CompInt8 = nullptr; + /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. * @@ -337,4 +504,35 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { float* AScaledGroupSum // scale_k * Sum_blklen(a_i) ); QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply fp16 matrix A rows with fp16 matrix B columns. + * Results are written to fp16 matrix C. + * If bias is provided, the bias are added to the result. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param Bias the bias at the target column. Optional. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param K the number of columns of A matrix and rows of B matrix. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. 
+ */ + typedef void(HQ4BitGemmKernel_CompFp16_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc + ); + + HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr; }; diff --git a/src/common/cpuid_info.cc b/src/common/cpuid_info.cc index 04172e4..0d996a0 100644 --- a/src/common/cpuid_info.cc +++ b/src/common/cpuid_info.cc @@ -1,11 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "core/common/cpuid_info.h" + +#include +#include + #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" +#include "core/platform/check_intel.h" #ifdef __linux__ - +#if (defined(_M_AMD64) || defined(__x86_64__)) && !defined(__ANDROID__) +#include +#endif #include #include #if !defined(__NR_getcpu) @@ -22,6 +29,10 @@ #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP_SVE +#define HWCAP_SVE (1 << 22) +#endif + #ifndef HWCAP2_I8MM #define HWCAP2_I8MM (1 << 13) #endif @@ -42,14 +53,20 @@ #include -#define HAS_WINDOWS_DESKTOP WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) - #ifndef PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE #define PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE 43 #endif #endif // _WIN32 +#if defined(__APPLE__) +#if defined(CPUIDINFO_ARCH_ARM) + +#include + +#endif // defined(CPUIDINFO_ARCH_ARM) +#endif // defined(__APPLE__) + #if defined(CPUINFO_SUPPORTED) #include #if defined(CPUIDINFO_ARCH_ARM) @@ -73,6 +90,14 @@ void decodeMIDR(uint32_t midr, uint32_t uarch[1]); namespace onnxruntime { +void CPUIDInfo::LogEarlyWarning(std::string_view message) { + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(WARNING) << message; + } else { + std::cerr << "onnxruntime cpuid_info warning: " << message << "\n"; + } +} + #if defined(CPUIDINFO_ARCH_X86) static inline void GetCPUID(int function_id, int data[4]) { // NOLINT @@ -123,7 +148,9 @@ void CPUIDInfo::X86Init() { has_f16c_ = has_avx_ && (data[2] & (1 << 29)) && (data[3] & (1 << 26)); if (num_IDs >= 7) { - GetCPUID(7, data); + // This change is made to overcome the issue of __get_cpuid returning all zeros, instead use __get_cpuid_count. + // Reference: https://stackoverflow.com/questions/46272579/why-does-get-cpuid-return-all-zeros-for-leaf-4 + GetCPUID(7, 0, data); const uint32_t max_SubLeaves = data[0]; has_amx_bf16_ = (data[3] & (1 << 22)); has_avx2_ = has_avx_ && (data[1] & (1 << 5)); @@ -132,6 +159,7 @@ void CPUIDInfo::X86Init() { // avx512_skylake = avx512f | avx512vl | avx512cd | avx512bw | avx512dq has_avx512_skylake_ = has_avx512 && (data[1] & ((1 << 16) | (1 << 17) | (1 << 28) | (1 << 30) | (1 << 31))); is_hybrid_ = (data[3] & (1 << 15)); + if (max_SubLeaves >= 1) { GetCPUID(7, 1, data); has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5)); @@ -155,8 +183,12 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + // SVE is enabled only on Linux-based ARM CPUs for now, where it has been tested. 
+ has_arm_sve_ = cpuinfo_has_arm_sve();
 has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
 has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
+ has_arm_sme_ = cpuinfo_has_arm_sme();
+ has_arm_sme2_ = cpuinfo_has_arm_sme2();
 const uint32_t core_cnt = cpuinfo_get_cores_count();
 core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
@@ -185,6 +217,7 @@ void CPUIDInfo::ArmLinuxInit() {
 has_fp16_ |= has_arm_neon_dot_;
 has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0);
+ has_arm_sve_ = ((getauxval(AT_HWCAP) & HWCAP_SVE) != 0);
 has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0);
 has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0);
@@ -194,21 +227,22 @@
 #elif defined(_WIN32)  // ^ defined(__linux__)
 void CPUIDInfo::ArmWindowsInit() {
-// ARM32 certainly doesn't have fp16, so we will skip the logic to avoid using RegGetValueA Windows API
-#if !defined(_M_ARM)
-#pragma region Application Family or OneCore Family
-#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM)
- // Read MIDR from windows registry
+ // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry
+ // There should be one per CPU
+ std::vector<uint64_t> midr_values{}, id_aa64isar1_el1_values{};
+ // TODO: Multiple processor groups are not supported yet.
 constexpr int MAX_CORES = 64;
 constexpr int MAX_VALUE_NAME = 4096;
- CHAR midrKey[MAX_VALUE_NAME] = "";  // buffer for processor registry name
- uint32_t lastUarch = cpuinfo_uarch_unknown;
- for (int i = 0; i < MAX_CORES - 1; i++) {
- snprintf(midrKey, MAX_VALUE_NAME, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\%d", i);
- uint64_t midrVal;
- unsigned long midrSize = sizeof(uint64_t);
+ CHAR processor_subkey[MAX_VALUE_NAME] = "";  // buffer for processor registry name
+
+ for (size_t i = 0; i < MAX_CORES - 1; i++) {
+ snprintf(processor_subkey, MAX_VALUE_NAME, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\%d",
+ static_cast<int>(i));
+
+ uint64_t midr_value;
+ unsigned long data_size = sizeof(midr_value);
 /*
 * ARM lists for each coprocessor register 5 fields: op0/op1/CRn/CRm/op2.
@@ -223,48 +257,65 @@ void CPUIDInfo::ArmWindowsInit() { * * For the CP value of MIDR, op0 = 3 and the others are all = 0, so we come up with 0x4000, */ - auto retCode = ::RegGetValueA(HKEY_LOCAL_MACHINE, midrKey, "CP 4000", RRF_RT_REG_QWORD, nullptr, &midrVal, &midrSize); - if (retCode != ERROR_SUCCESS) { + if (::RegGetValueA(HKEY_LOCAL_MACHINE, processor_subkey, "CP 4000", RRF_RT_REG_QWORD, + nullptr, &midr_value, &data_size) != ERROR_SUCCESS) { break; } - uint32_t uarch = cpuinfo_uarch_unknown; - decodeMIDR((uint32_t)midrVal, &uarch); - core_uarchs_.push_back(uarch); - if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || - uarch == cpuinfo_uarch_cortex_a55) { - is_armv8_narrow_ld_.push_back(true); - } else { - is_armv8_narrow_ld_.push_back(false); + + uint64_t id_aa64isar1_el1_value; + data_size = sizeof(id_aa64isar1_el1_value); + + // CP 4031 corresponds to ID_AA64ISAR1_EL1 register + if (::RegGetValueA(HKEY_LOCAL_MACHINE, processor_subkey, "CP 4031", RRF_RT_REG_QWORD, + nullptr, &id_aa64isar1_el1_value, &data_size) != ERROR_SUCCESS) { + break; } - if (i == 0) { - lastUarch = uarch; - } else if (lastUarch != uarch) { - is_hybrid_ = true; - lastUarch = uarch; + midr_values.push_back(midr_value); + id_aa64isar1_el1_values.push_back(id_aa64isar1_el1_value); + } + + // process midr_values + { + uint32_t lastUarch = cpuinfo_uarch_unknown; + for (size_t i = 0; i < midr_values.size(); ++i) { + uint32_t uarch = cpuinfo_uarch_unknown; + decodeMIDR(static_cast(midr_values[i]), &uarch); + core_uarchs_.push_back(uarch); + if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || + uarch == cpuinfo_uarch_cortex_a55) { + is_armv8_narrow_ld_.push_back(true); + } else { + is_armv8_narrow_ld_.push_back(false); + } + + if (i == 0) { + lastUarch = uarch; + } else if (lastUarch != uarch) { + is_hybrid_ = true; + lastUarch = uarch; + } } } -#endif // WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) + + has_arm_neon_i8mm_ = std::all_of( + id_aa64isar1_el1_values.begin(), id_aa64isar1_el1_values.end(), + [](uint64_t id_aa64isar1_el1_value) { + // I8MM, bits [55:52] + return ((id_aa64isar1_el1_value >> 52) & 0xF) != 0; + }); has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#else // ^ !defined(_M_ARM) / v defined(_M_ARM) - has_arm_neon_dot_ = false; -#endif // defined(_M_ARM) #if defined(CPUINFO_SUPPORTED) if (pytorch_cpuinfo_init_) { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); - has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); - has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + // cpuinfo_has_arm_i8mm() doesn't work on Windows yet. See https://github.com/pytorch/cpuinfo/issues/279. 
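+ // has_arm_neon_i8mm_ is instead derived above from the ID_AA64ISAR1_EL1 I8MM field (bits [55:52])
+ // read from the registry for each core.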
+ // has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && has_arm_neon_i8mm_; has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); - } else -#endif // defined(CPUINFO_SUPPORTED) - { - has_fp16_ = false; - has_arm_neon_i8mm_ = false; - has_arm_sve_i8mm_ = false; - has_arm_neon_bf16_ = false; } +#endif // defined(CPUINFO_SUPPORTED) } #elif defined(__APPLE__) // ^ defined(_WIN32) @@ -278,6 +329,8 @@ void CPUIDInfo::ArmAppleInit() { has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); + has_arm_sme2_ = cpuinfo_has_arm_sme2(); // Note: We leave is_armv8_narrow_ld_ unset because it only applies to a limited set of uarchs that we don't expect // to encounter on Apple platforms. @@ -310,16 +363,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const { } CPUIDInfo::CPUIDInfo() { -#ifdef CPUIDINFO_ARCH_X86 - X86Init(); -#elif defined(CPUIDINFO_ARCH_ARM) #if defined(CPUINFO_SUPPORTED) pytorch_cpuinfo_init_ = cpuinfo_initialize(); if (!pytorch_cpuinfo_init_) { - LOGS_DEFAULT(WARNING) << "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation " - "due to undetected CPU features."; + LogEarlyWarning( + "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation due to undetected CPU " + "features."); } #endif // defined(CPUINFO_SUPPORTED) + + // Note: This should be run after cpuinfo initialization if cpuinfo is enabled. + VendorInfoInit(); + +#ifdef CPUIDINFO_ARCH_X86 + X86Init(); +#elif defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) ArmLinuxInit(); #elif defined(_WIN32) diff --git a/src/common/cpuid_info_vendor.cc b/src/common/cpuid_info_vendor.cc new file mode 100644 index 0000000..d4d940e --- /dev/null +++ b/src/common/cpuid_info_vendor.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/cpuid_info.h" + +#include +#include +#include + +#if defined(CPUINFO_SUPPORTED) +#include "cpuinfo.h" +#endif + +namespace { + +#if !defined(CPUINFO_SUPPORTED) + +// The `cpuinfo_vendor` enum is defined by the cpuinfo library. +// In case we don't build with cpuinfo, we define our own copy. +// The enum was copied from here: +// https://github.com/pytorch/cpuinfo/blob/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6/include/cpuinfo.h#L154-L307 + +/** Vendor of processor core design */ +enum cpuinfo_vendor { + /** Processor vendor is not known to the library, or the library failed + to get vendor information from the OS. */ + cpuinfo_vendor_unknown = 0, + + /* Active vendors of modern CPUs */ + + /** + * Intel Corporation. Vendor of x86, x86-64, IA64, and ARM processor + * microarchitectures. + * + * Sold its ARM design subsidiary in 2006. The last ARM processor design + * was released in 2004. + */ + cpuinfo_vendor_intel = 1, + /** Advanced Micro Devices, Inc. Vendor of x86 and x86-64 processor + microarchitectures. */ + cpuinfo_vendor_amd = 2, + /** ARM Holdings plc. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_arm = 3, + /** Qualcomm Incorporated. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_qualcomm = 4, + /** Apple Inc. Vendor of ARM and ARM64 processor microarchitectures. */ + cpuinfo_vendor_apple = 5, + /** Samsung Electronics Co., Ltd. Vendir if ARM64 processor + microarchitectures. 
*/ + cpuinfo_vendor_samsung = 6, + /** Nvidia Corporation. Vendor of ARM64-compatible processor + microarchitectures. */ + cpuinfo_vendor_nvidia = 7, + /** MIPS Technologies, Inc. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_mips = 8, + /** International Business Machines Corporation. Vendor of PowerPC + processor microarchitectures. */ + cpuinfo_vendor_ibm = 9, + /** Ingenic Semiconductor. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_ingenic = 10, + /** + * VIA Technologies, Inc. Vendor of x86 and x86-64 processor + * microarchitectures. + * + * Processors are designed by Centaur Technology, a subsidiary of VIA + * Technologies. + */ + cpuinfo_vendor_via = 11, + /** Cavium, Inc. Vendor of ARM64 processor microarchitectures. */ + cpuinfo_vendor_cavium = 12, + /** Broadcom, Inc. Vendor of ARM processor microarchitectures. */ + cpuinfo_vendor_broadcom = 13, + /** Applied Micro Circuits Corporation (APM). Vendor of ARM64 processor + microarchitectures. */ + cpuinfo_vendor_apm = 14, + /** + * Huawei Technologies Co., Ltd. Vendor of ARM64 processor + * microarchitectures. + * + * Processors are designed by HiSilicon, a subsidiary of Huawei. + */ + cpuinfo_vendor_huawei = 15, + /** + * Hygon (Chengdu Haiguang Integrated Circuit Design Co., Ltd), Vendor + * of x86-64 processor microarchitectures. + * + * Processors are variants of AMD cores. + */ + cpuinfo_vendor_hygon = 16, + /** SiFive, Inc. Vendor of RISC-V processor microarchitectures. */ + cpuinfo_vendor_sifive = 17, + + /* Active vendors of embedded CPUs */ + + /** Texas Instruments Inc. Vendor of ARM processor microarchitectures. + */ + cpuinfo_vendor_texas_instruments = 30, + /** Marvell Technology Group Ltd. Vendor of ARM processor + * microarchitectures. + */ + cpuinfo_vendor_marvell = 31, + /** RDC Semiconductor Co., Ltd. Vendor of x86 processor + microarchitectures. */ + cpuinfo_vendor_rdc = 32, + /** DM&P Electronics Inc. Vendor of x86 processor microarchitectures. */ + cpuinfo_vendor_dmp = 33, + /** Motorola, Inc. Vendor of PowerPC and ARM processor + microarchitectures. */ + cpuinfo_vendor_motorola = 34, + + /* Defunct CPU vendors */ + + /** + * Transmeta Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 2004. + * Transmeta processors implemented VLIW ISA and used binary translation + * to execute x86 code. + */ + cpuinfo_vendor_transmeta = 50, + /** + * Cyrix Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1996. + */ + cpuinfo_vendor_cyrix = 51, + /** + * Rise Technology. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1999. + */ + cpuinfo_vendor_rise = 52, + /** + * National Semiconductor. Vendor of x86 processor microarchitectures. + * + * Sold its x86 design subsidiary in 1999. The last processor design was + * released in 1998. + */ + cpuinfo_vendor_nsc = 53, + /** + * Silicon Integrated Systems. Vendor of x86 processor + * microarchitectures. + * + * Sold its x86 design subsidiary in 2001. The last processor design was + * released in 2001. + */ + cpuinfo_vendor_sis = 54, + /** + * NexGen. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1994. + * NexGen designed the first x86 microarchitecture which decomposed x86 + * instructions into simple microoperations. 
+ */ + cpuinfo_vendor_nexgen = 55, + /** + * United Microelectronics Corporation. Vendor of x86 processor + * microarchitectures. + * + * Ceased x86 in the early 1990s. The last processor design was released + * in 1991. Designed U5C and U5D processors. Both are 486 level. + */ + cpuinfo_vendor_umc = 56, + /** + * Digital Equipment Corporation. Vendor of ARM processor + * microarchitecture. + * + * Sold its ARM designs in 1997. The last processor design was released + * in 1997. + */ + cpuinfo_vendor_dec = 57, +}; + +#endif // !defined(CPUINFO_SUPPORTED) + +} // namespace + +namespace onnxruntime { + +namespace { + +struct CpuVendorInfo { + cpuinfo_vendor vendor; + std::string_view name; + uint32_t id; +}; + +constexpr auto kUnknownCpuVendorInfo = CpuVendorInfo{cpuinfo_vendor_unknown, "unknown", 0x0000}; + +constexpr std::array kCpuVendorInfos{ + CpuVendorInfo{cpuinfo_vendor_amd, "AMD", 0x1022}, + CpuVendorInfo{cpuinfo_vendor_intel, "Intel", 0x8086}, + CpuVendorInfo{cpuinfo_vendor_qualcomm, "Qualcomm", uint32_t{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}}, + CpuVendorInfo{cpuinfo_vendor_nvidia, "Nvidia", 0x10DE}, + CpuVendorInfo{cpuinfo_vendor_apple, "Apple", 0x106B}, + CpuVendorInfo{cpuinfo_vendor_arm, "ARM", 0x13B5}, + + // TODO add more as needed +}; + +const CpuVendorInfo* FindCpuVendorInfo(cpuinfo_vendor vendor) { + const auto vendor_mapping_it = std::find_if(kCpuVendorInfos.begin(), kCpuVendorInfos.end(), + [vendor](const CpuVendorInfo& entry) { + return entry.vendor == vendor; + }); + + if (vendor_mapping_it != kCpuVendorInfos.end()) { + return &*vendor_mapping_it; + } + + return nullptr; +} + +} // namespace + +void CPUIDInfo::VendorInfoInit() { + const cpuinfo_vendor vendor = [&]() { + cpuinfo_vendor result = cpuinfo_vendor_unknown; +#if defined(CPUINFO_SUPPORTED) + if (pytorch_cpuinfo_init_) { + const auto* processor = cpuinfo_get_processor(0); + if (processor && processor->core) { + result = processor->core->vendor; + } + } +#endif // defined(CPUINFO_SUPPORTED) + return result; + }(); + + const auto* vendor_info = FindCpuVendorInfo(vendor); + if (vendor_info == nullptr) { + LogEarlyWarning(MakeString("Unknown CPU vendor. 
cpuinfo_vendor value: ", static_cast(vendor))); + vendor_info = &kUnknownCpuVendorInfo; + } + + vendor_ = vendor_info->name; + vendor_id_ = vendor_info->id; +} + +} // namespace onnxruntime diff --git a/src/common/cpuid_uarch.cc b/src/common/cpuid_uarch.cc index 16634b2..28a3524 100644 --- a/src/common/cpuid_uarch.cc +++ b/src/common/cpuid_uarch.cc @@ -30,9 +30,11 @@ inline static uint32_t midr_get_part(uint32_t midr) { return (midr & CPUINFO_ARM_MIDR_PART_MASK) >> CPUINFO_ARM_MIDR_PART_OFFSET; } +#if 0 inline static uint32_t midr_get_variant(uint32_t midr) { return (midr & CPUINFO_ARM_MIDR_VARIANT_MASK) >> CPUINFO_ARM_MIDR_VARIANT_OFFSET; } +#endif void decodeMIDR( uint32_t midr, @@ -137,8 +139,8 @@ void decodeMIDR( *uarch = cpuinfo_uarch_arm11; break; // #endif /* ARM */ - default: - std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } } break; @@ -156,8 +158,8 @@ void decodeMIDR( *uarch = cpuinfo_uarch_thunderx2; break; // #endif - default: - std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if (defined(_M_ARM64) || defined(__aarch64__)) && !defined(__ANDROID__) @@ -172,8 +174,8 @@ void decodeMIDR( case 0x0AF: /* ThunderX2 99XX */ *uarch = cpuinfo_uarch_thunderx2; break; - default: - std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif @@ -187,8 +189,8 @@ void decodeMIDR( case 0xD40: /* Kirin 980 Big/Medium cores -> Cortex-A76 */ *uarch = cpuinfo_uarch_cortex_a76; break; - default: - std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -199,8 +201,8 @@ void decodeMIDR( case 6: /* PXA 3XX */ *uarch = cpuinfo_uarch_xscale; break; - default: - std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ @@ -215,8 +217,8 @@ void decodeMIDR( case 0x004: *uarch = cpuinfo_uarch_carmel; break; - default: - std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #if !defined(__ANDROID__) @@ -225,8 +227,8 @@ void decodeMIDR( case 0x000: *uarch = cpuinfo_uarch_xgene; break; - default: - std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #endif @@ -297,8 +299,8 @@ void decodeMIDR( *uarch = cpuinfo_uarch_saphira; break; // #endif /* ARM64 && !defined(__ANDROID__) */ - default: - std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; case 'S': 
@@ -343,10 +345,10 @@ void decodeMIDR( */ *uarch = cpuinfo_uarch_exynos_m5; break; - default: - std::cerr << "unknown Samsung CPU variant 0x" - << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) - << " ignored\n"; + // default: + // std::cerr << "unknown Samsung CPU variant 0x" + //<< std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) + //<< " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -356,13 +358,13 @@ void decodeMIDR( case 0x584: /* PJ4B-MP / PJ4C */ *uarch = cpuinfo_uarch_pj4; break; - default: - std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ - default: - std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n"; + // default: + // std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n"; } } diff --git a/src/common/helper.cc b/src/common/helper.cc index 6a52db7..07cd167 100644 --- a/src/common/helper.cc +++ b/src/common/helper.cc @@ -18,7 +18,7 @@ namespace onnxruntime { #ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { +std::string ToUTF8String(std::wstring_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); @@ -33,7 +33,7 @@ std::string ToUTF8String(const std::wstring& s) { return ret; } -std::wstring ToWideString(const std::string& s) { +std::wstring ToWideString(std::string_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); diff --git a/src/common/logging/logging.cc b/src/common/logging/logging.cc index a326095..4c94a29 100644 --- a/src/common/logging/logging.cc +++ b/src/common/logging/logging.cc @@ -63,13 +63,13 @@ LoggingManager* LoggingManager::GetDefaultInstance() { #pragma warning(disable : 26426) #endif -static OrtMutex& DefaultLoggerMutex() noexcept { - static OrtMutex mutex; +static std::mutex& DefaultLoggerMutex() noexcept { + static std::mutex mutex; return mutex; } Logger* LoggingManager::s_default_logger_ = nullptr; -OrtMutex sink_mutex_; +std::mutex sink_mutex_; #ifdef _MSC_VER #pragma warning(pop) @@ -106,7 +106,7 @@ LoggingManager::LoggingManager(std::unique_ptr sink, Severity default_min // lock mutex to create instance, and enable logging // this matches the mutex usage in Shutdown - std::lock_guard guard(DefaultLoggerMutex()); + std::lock_guard guard(DefaultLoggerMutex()); if (DefaultLoggerManagerInstance().load() != nullptr) { ORT_THROW("Only one instance of LoggingManager created with InstanceType::Default can exist at any point in time."); @@ -126,7 +126,7 @@ LoggingManager::LoggingManager(std::unique_ptr sink, Severity default_min LoggingManager::~LoggingManager() { if (owns_default_logger_) { // lock mutex to reset DefaultLoggerManagerInstance() and free default logger from this instance. 
- std::lock_guard guard(DefaultLoggerMutex()); + std::lock_guard guard(DefaultLoggerMutex()); #if ((__cplusplus >= 201703L) || (defined(_MSVC_LANG) && (_MSVC_LANG >= 201703L))) DefaultLoggerManagerInstance().store(nullptr, std::memory_order_release); #else @@ -249,10 +249,9 @@ unsigned int GetProcessId() { #endif } - bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function()> sinkFactory, logging::Severity severity) { - std::lock_guard guard(sink_mutex_); + std::lock_guard guard(sink_mutex_); if (sink_->GetType() != SinkType::CompositeSink) { // Current sink is not a composite, create a new composite sink and add the current sink to it auto new_composite = std::make_unique(); @@ -274,7 +273,7 @@ bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function guard(sink_mutex_); + std::lock_guard guard(sink_mutex_); if (sink_->GetType() == SinkType::CompositeSink) { auto composite_sink = static_cast(sink_.get()); diff --git a/src/common/profiler.cc b/src/common/profiler.cc index 71bca6e..8562e55 100644 --- a/src/common/profiler.cc +++ b/src/common/profiler.cc @@ -85,7 +85,7 @@ void Profiler::EndTimeAndRecordEvent(EventCategory category, custom_logger_->SendProfileEvent(event); } else { // TODO: sync_gpu if needed. - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); if (events_.size() < max_num_events_) { events_.emplace_back(std::move(event)); } else { @@ -115,7 +115,7 @@ std::string Profiler::EndProfiling() { LOGS(*session_logger_, INFO) << "Writing profiler data to file " << profile_stream_file_; } - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); profile_stream_ << "[\n"; for (const auto& ep_profiler : ep_profilers_) { diff --git a/src/common/profiler.h b/src/common/profiler.h index a0bca00..0103d8a 100644 --- a/src/common/profiler.h +++ b/src/common/profiler.h @@ -11,7 +11,7 @@ #include "core/common/profiler_common.h" #include "core/common/logging/logging.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -130,7 +130,7 @@ class Profiler { static std::atomic global_max_num_events_; // Mutex controlling access to profiler data - OrtMutex mutex_; + std::mutex mutex_; bool enabled_{false}; #if defined(__wasm__) /* diff --git a/src/common/safeint.h b/src/common/safeint.h index 3ee70f3..6aba587 100644 --- a/src/common/safeint.h +++ b/src/common/safeint.h @@ -13,11 +13,11 @@ class SafeIntExceptionHandler; template <> class SafeIntExceptionHandler { public: - static void SafeIntOnOverflow() { + [[noreturn]] static void SafeIntOnOverflow() { ORT_THROW("Integer overflow"); } - static void SafeIntOnDivZero() { + [[noreturn]] static void SafeIntOnDivZero() { ORT_THROW("Divide by zero"); } }; diff --git a/src/common/semver.h b/src/common/semver.h new file mode 100644 index 0000000..98bb6a2 --- /dev/null +++ b/src/common/semver.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/status.h" + +namespace onnxruntime { + +// Semantic Versioning version utilities. +// See https://github.com/semver/semver/blob/v2.0.0/semver.md. + +// Semantic Versioning version components. +struct SemVerVersion { + uint32_t major{}; + uint32_t minor{}; + uint32_t patch{}; + std::optional prerelease{}; + std::optional build_metadata{}; +}; + +// Parse a Semantic Versioning version from `version_string`. +// If provided, the parsed version components will be written to `semver_version`. 
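+//
+// For example, "1.2.3-rc.1+build.5" consists of major 1, minor 2, patch 3,
+// prerelease "rc.1", and build metadata "build.5".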
+Status ParseSemVerVersion(std::string_view version_string, SemVerVersion* semver_version); + +// Parse a Semantic Versioning version from `version_string`. +SemVerVersion ParseSemVerVersion(std::string_view version_string); + +} // namespace onnxruntime diff --git a/src/common/spin_pause.cc b/src/common/spin_pause.cc new file mode 100644 index 0000000..7b1de8d --- /dev/null +++ b/src/common/spin_pause.cc @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/spin_pause.h" + +#if defined(_M_AMD64) +#include +#endif + +#if defined(__x86_64__) +#include +#endif + +#if defined(_M_AMD64) || defined(__x86_64__) +#include "core/common/cpuid_info.h" +#if defined(__linux__) +#include +#include +#endif +#endif + +namespace onnxruntime { +namespace concurrency { + +// Intrinsic to use in spin-loops +void SpinPause() { +#if (defined(_M_AMD64) || defined(__x86_64__)) && \ + !defined(_M_ARM64EC) && \ + !defined(__ANDROID__) && \ + !defined(__APPLE__) + _mm_pause(); +#endif +} + +} // namespace concurrency +} // namespace onnxruntime diff --git a/src/common/string_utils.h b/src/common/string_utils.h index 716eed1..d8d943d 100644 --- a/src/common/string_utils.h +++ b/src/common/string_utils.h @@ -3,6 +3,8 @@ #pragma once +#include +#include #include #include #include @@ -59,10 +61,11 @@ inline void TrimStringFromRight(std::string& s) { * @param s The string to trim. * @return The trimmed string. */ -inline std::string TrimString(std::string s) { - TrimStringFromRight(s); - TrimStringFromLeft(s); - return s; +inline std::string TrimString(std::string_view s) { + std::string s_trimmed{s}; + TrimStringFromRight(s_trimmed); + TrimStringFromLeft(s_trimmed); + return s_trimmed; } /** @@ -84,5 +87,21 @@ inline uint32_t GetHashFromString(const std::string& str_value) { return hash; } +/** + * Returns a lowercase version of the input string. + * @param str The string to lowercase. + * @return The lowercased string. + */ +inline std::string GetLowercaseString(std::string str) { + // https://en.cppreference.com/w/cpp/string/byte/tolower + // The behavior of tolower from is undefined if the argument is neither representable as unsigned char + // nor equal to EOF. To use tolower safely with a plain char (or signed char), the argument must be converted to + // unsigned char. + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return str; +} + } // namespace utils } // namespace onnxruntime diff --git a/src/common/threadpool.cc b/src/common/threadpool.cc index 52d1c1e..0cfdf08 100644 --- a/src/common/threadpool.cc +++ b/src/common/threadpool.cc @@ -21,11 +21,11 @@ limitations under the License. 
#include "core/common/cpuid_info.h" #include "core/common/eigen_common_wrapper.h" #include "core/platform/EigenNonBlockingThreadPool.h" -#include "core/platform/ort_mutex.h" +#include #if !defined(ORT_MINIMAL_BUILD) #ifdef _WIN32 #include -#include +#include "processthreadsapi.h" #include #include #elif defined(__APPLE__) @@ -46,202 +46,6 @@ namespace onnxruntime { namespace concurrency { -#if !defined(ORT_MINIMAL_BUILD) -ThreadPoolProfiler::ThreadPoolProfiler(int num_threads, const CHAR_TYPE* thread_pool_name) : num_threads_(num_threads) { - child_thread_stats_.assign(num_threads, {}); - if (thread_pool_name) { -#ifdef _WIN32 - thread_pool_name_ = ToUTF8String(thread_pool_name); -#else - thread_pool_name_ = thread_pool_name; -#endif - } else { - thread_pool_name_ = "unnamed_thread_pool"; - } -} - -ThreadPoolProfiler::~ThreadPoolProfiler() { - enabled_ = false; -} - -void ThreadPoolProfiler::Start() { - enabled_ = true; -} - -ThreadPoolProfiler::MainThreadStat& ThreadPoolProfiler::GetMainThreadStat() { - static thread_local std::unique_ptr stat; - if (!stat) { - stat = std::make_unique(); - } - return *stat; -} - -std::string ThreadPoolProfiler::Stop() { - ORT_ENFORCE(enabled_, "Profiler not started yet"); - std::ostringstream ss; - ss << "{\"main_thread\": {" - << "\"thread_pool_name\": \"" - << thread_pool_name_ << "\", " - << GetMainThreadStat().Reset() - << "}, \"sub_threads\": {" - << DumpChildThreadStat() - << "}}"; - return ss.str(); -} - -void ThreadPoolProfiler::LogStartAndCoreAndBlock(std::ptrdiff_t block_size) { - if (enabled_) { - MainThreadStat& stat = GetMainThreadStat(); - stat.LogCore(); - stat.LogBlockSize(block_size); - stat.LogStart(); - } -} - -void ThreadPoolProfiler::LogCoreAndBlock(std::ptrdiff_t block_size) { - if (enabled_) { - MainThreadStat& stat = GetMainThreadStat(); - stat.LogCore(); - stat.LogBlockSize(block_size); - } -} - -void ThreadPoolProfiler::LogStart() { - if (enabled_) { - GetMainThreadStat().LogStart(); - } -} - -void ThreadPoolProfiler::LogEnd(ThreadPoolEvent evt) { - if (enabled_) { - GetMainThreadStat().LogEnd(evt); - } -} - -void ThreadPoolProfiler::LogEndAndStart(ThreadPoolEvent evt) { - if (enabled_) { - GetMainThreadStat().LogEndAndStart(evt); - } -} - -void ThreadPoolProfiler::MainThreadStat::LogCore() { -#ifdef _WIN32 - core_ = GetCurrentProcessorNumber(); -#elif defined(__APPLE__) -#if defined(__x86_64__) || defined(__i386__) - uint32_t CPUInfo[4]; - __cpuid_count(1, 0, CPUInfo[0], CPUInfo[1], CPUInfo[2], CPUInfo[3]); - if ((CPUInfo[3] & (1 << 9)) != 0) { - core_ = (unsigned)CPUInfo[1] >> 24; - } -#endif -#elif defined(__wasm__) - core_ = emscripten_num_logical_cores(); -#elif defined(_AIX) - core_ = mycpu(); -#else - core_ = sched_getcpu(); -#endif -} - -void ThreadPoolProfiler::MainThreadStat::LogBlockSize(std::ptrdiff_t block_size) { - blocks_.emplace_back(block_size); -} - -void ThreadPoolProfiler::MainThreadStat::LogStart() { - points_.emplace_back(Clock::now()); -} - -void ThreadPoolProfiler::MainThreadStat::LogEnd(ThreadPoolEvent evt) { - ORT_ENFORCE(!points_.empty(), "LogStart must pair with LogEnd"); - events_[evt] += TimeDiffMicroSeconds(points_.back(), Clock::now()); - points_.pop_back(); -} - -void ThreadPoolProfiler::MainThreadStat::LogEndAndStart(ThreadPoolEvent evt) { - ORT_ENFORCE(!points_.empty(), "LogStart must pair with LogEnd"); - events_[evt] += TimeDiffMicroSeconds(points_.back(), Clock::now()); - points_.back() = Clock::now(); -} - -std::string ThreadPoolProfiler::MainThreadStat::Reset() { - 
ORT_ENFORCE(points_.empty(), "LogStart must pair with LogEnd"); - std::stringstream ss; - ss << "\"thread_id\": \"" << std::this_thread::get_id() << "\", \"block_size\": ["; - if (!blocks_.empty()) { - std::copy(blocks_.begin(), blocks_.end() - 1, std::ostream_iterator(ss, ", ")); - ss << blocks_.back(); - blocks_.clear(); - } - ss << "], \"core\": " << core_ << ", "; - for (int i = 0; i < MAX_EVENT; ++i) { - ss << "\"" << ThreadPoolProfiler::GetEventName(static_cast(i)) - << "\": " << events_[i] << ((i == MAX_EVENT - 1) ? std::string{} : ", "); - } - memset(events_, 0, sizeof(uint64_t) * MAX_EVENT); - return ss.str(); -} - -const char* ThreadPoolProfiler::GetEventName(ThreadPoolEvent event) { - switch (event) { - case DISTRIBUTION: - return "Distribution"; - case DISTRIBUTION_ENQUEUE: - return "DistributionEnqueue"; - case RUN: - return "Run"; - case WAIT: - return "Wait"; - case WAIT_REVOKE: - return "WaitRevoke"; - default: - return "UnknownEvent"; - } -} - -void ThreadPoolProfiler::LogThreadId(int thread_idx) { - child_thread_stats_[thread_idx].thread_id_ = std::this_thread::get_id(); -} - -void ThreadPoolProfiler::LogRun(int thread_idx) { - if (enabled_) { - child_thread_stats_[thread_idx].num_run_++; - auto now = Clock::now(); - if (child_thread_stats_[thread_idx].core_ < 0 || - TimeDiffMicroSeconds(child_thread_stats_[thread_idx].last_logged_point_, now) > 10000) { -#ifdef _WIN32 - child_thread_stats_[thread_idx].core_ = GetCurrentProcessorNumber(); -#elif defined(__APPLE__) -#if defined(__x86_64__) || defined(__i386__) - uint32_t CPUInfo[4]; - __cpuid_count(1, 0, CPUInfo[0], CPUInfo[1], CPUInfo[2], CPUInfo[3]); - if ((CPUInfo[3] & (1 << 9)) != 0) { - child_thread_stats_[thread_idx].core_ = (unsigned)CPUInfo[1] >> 24; - } -#endif -#elif defined(__wasm__) - child_thread_stats_[thread_idx].core_ = emscripten_num_logical_cores(); -#elif defined(_AIX) - child_thread_stats_[thread_idx].core_ = mycpu(); -#else - child_thread_stats_[thread_idx].core_ = sched_getcpu(); -#endif - child_thread_stats_[thread_idx].last_logged_point_ = now; - } - } -} - -std::string ThreadPoolProfiler::DumpChildThreadStat() { - std::stringstream ss; - for (int i = 0; i < num_threads_; ++i) { - ss << "\"" << child_thread_stats_[i].thread_id_ << "\": {" - << "\"num_run\": " << child_thread_stats_[i].num_run_ << ", " - << "\"core\": " << child_thread_stats_[i].core_ << "}" - << (i == num_threads_ - 1 ? "" : ","); - } - return ss.str(); -} -#endif // A sharded loop counter distributes loop iterations between a set of worker threads. The iteration space of // the loop is divided (perhaps unevenly) between the shards. Each thread has a home shard (perhaps not uniquely @@ -439,7 +243,7 @@ void ThreadPool::ParallelForFixedBlockSizeScheduling(const std::ptrdiff_t total, // threads is handled within RunInParallel, hence we can deallocate lc and other state captured by // run_work. 
RunInParallel(run_work, num_work_items, block_size); - } + } } void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, const std::function& fn) { @@ -458,19 +262,7 @@ void ThreadPool::Schedule(std::function fn) { } } -void ThreadPool::StartProfiling() { - if (underlying_threadpool_) { - underlying_threadpool_->StartProfiling(); - } -} -std::string ThreadPool::StopProfiling() { - if (underlying_threadpool_) { - return underlying_threadpool_->StopProfiling(); - } else { - return {}; - } -} namespace { thread_local std::optional current_parallel_section; @@ -628,19 +420,6 @@ int ThreadPool::DegreeOfParallelism(const concurrency::ThreadPool* tp) { } } -void ThreadPool::StartProfiling(concurrency::ThreadPool* tp) { - if (tp) { - tp->StartProfiling(); - } -} - -std::string ThreadPool::StopProfiling(concurrency::ThreadPool* tp) { - if (tp) { - return tp->StopProfiling(); - } else { - return {}; - } -} void ThreadPool::EnableSpinning() { if (extended_eigen_threadpool_) { diff --git a/src/core/platform/check_intel.h b/src/core/platform/check_intel.h new file mode 100644 index 0000000..1b82940 --- /dev/null +++ b/src/core/platform/check_intel.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +typedef struct { + bool is_intel; + bool is_intel_specified_platform; +} CheckIntelResult; + +CheckIntelResult CheckIntel(); +} // namespace onnxruntime diff --git a/src/core/platform/posix/env.cc b/src/core/platform/posix/env.cc index b956268..9e33bcc 100644 --- a/src/core/platform/posix/env.cc +++ b/src/core/platform/posix/env.cc @@ -62,28 +62,13 @@ namespace { constexpr int OneMillion = 1000000; -class UnmapFileParam { - public: - void* addr; - size_t len; -}; - -static void UnmapFile(void* param) noexcept { - std::unique_ptr p(reinterpret_cast(param)); - int ret = munmap(p->addr, p->len); - if (ret != 0) { - auto [err_no, err_msg] = GetErrnoInfo(); - LOGS_DEFAULT(ERROR) << "munmap failed. error code: " << err_no << " error msg: " << err_msg; - } -} - struct FileDescriptorTraits { using Handle = int; static Handle GetInvalidHandleValue() { return -1; } static void CleanUp(Handle h) { if (close(h) == -1) { auto [err_no, err_msg] = GetErrnoInfo(); - LOGS_DEFAULT(ERROR) << "Failed to close file descriptor " << h << " - error code: " << err_no + std::cout << "Failed to close file descriptor " << h << " - error code: " << err_no << " error msg: " << err_msg; } } @@ -104,19 +89,6 @@ long int TempFailureRetry(TFunc retriable_operation, TFuncArgs&&... args) { return result; } -// nftw() callback to remove a file -int nftw_remove( - const char* fpath, const struct stat* /*sb*/, - int /*typeflag*/, struct FTW* /*ftwbuf*/) { - const auto result = remove(fpath); - if (result != 0) { - auto [err_no, err_msg] = GetErrnoInfo(); - LOGS_DEFAULT(WARNING) << "remove() failed. 
Error code: " << err_no << " error msg: " << err_msg - << ", path: " << fpath; - } - return result; -} - template struct Freer { void operator()(T* p) { ::free(p); } @@ -148,14 +120,13 @@ class PosixThread : public EnvThread { unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, const ThreadOptions& thread_options) { ORT_ENFORCE(index >= 0, "Negative thread index is not allowed"); - - + auto param_ptr = std::make_unique(name_prefix, index, start_address, param); if (narrow(index) < thread_options.affinities.size()) { param_ptr->affinity = thread_options.affinities[index]; } - { + { pthread_attr_t attr; int s = pthread_attr_init(&attr); if (s != 0) { @@ -183,7 +154,7 @@ class PosixThread : public EnvThread { } ~PosixThread() override { - { + { void* res; #ifdef NDEBUG pthread_join(hThread, &res); @@ -208,19 +179,19 @@ class PosixThread : public EnvThread { } else { // Logical processor id starts from 0 internally, but in ort API, it starts from 1, // that's why id need to increase by 1 when logging. - LOGS_DEFAULT(ERROR) << "cpu " << id + 1 << " does not exist, skipping it for affinity setting"; + std::cout << "cpu " << id + 1 << " does not exist, skipping it for affinity setting"; } } auto ret = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); if (0 == ret) { - LOGS_DEFAULT(VERBOSE) << "pthread_setaffinity_np succeed for thread: " << syscall(SYS_gettid) + std::cout << "pthread_setaffinity_np succeed for thread: " << syscall(SYS_gettid) << ", index: " << p->index << ", mask: " << *p->affinity; } else { errno = ret; auto [err_no, err_msg] = GetErrnoInfo(); #if !defined(USE_MIGRAPHX) - LOGS_DEFAULT(ERROR) << "pthread_setaffinity_np failed for thread: " << syscall(SYS_gettid) + std::cout << "pthread_setaffinity_np failed for thread: " << syscall(SYS_gettid) << ", index: " << p->index << ", mask: " << *p->affinity << ", error code: " << err_no << " error msg: " << err_msg @@ -313,105 +284,12 @@ class PosixEnv : public Env { #endif } - void SleepForMicroseconds(int64_t micros) const override { - while (micros > 0) { - timespec sleep_time; - sleep_time.tv_sec = 0; - sleep_time.tv_nsec = 0; - - if (micros >= OneMillion) { - sleep_time.tv_sec = static_cast(std::min(micros / OneMillion, - std::numeric_limits::max())); - micros -= static_cast(sleep_time.tv_sec) * OneMillion; - } - if (micros < OneMillion) { - sleep_time.tv_nsec = static_cast(1000 * micros); - micros = 0; - } - while (nanosleep(&sleep_time, &sleep_time) != 0 && errno == EINTR) { - // Ignore signals and wait for the full interval to elapse. 
- } - } - } PIDType GetSelfPid() const override { return getpid(); } - Status GetFileLength(const PathChar* file_path, size_t& length) const override { - ScopedFileDescriptor file_descriptor{open(file_path, O_RDONLY)}; - return GetFileLength(file_descriptor.Get(), length); - } - - common::Status GetFileLength(int fd, /*out*/ size_t& file_size) const override { - using namespace common; - if (fd < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid fd was supplied: ", fd); - } - - struct stat buf; - int rc = fstat(fd, &buf); - if (rc < 0) { - return ReportSystemError("fstat", ""); - } - - if (buf.st_size < 0) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "Received negative size from stat call"); - } - - if (static_cast(buf.st_size) > std::numeric_limits::max()) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "File is too large."); - } - - file_size = static_cast(buf.st_size); - return Status::OK(); - } - - Status ReadFileIntoBuffer(const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, - gsl::span buffer) const override { - ORT_RETURN_IF_NOT(file_path, "file_path == nullptr"); - ORT_RETURN_IF_NOT(offset >= 0, "offset < 0"); - ORT_RETURN_IF_NOT(length <= buffer.size(), "length > buffer.size()"); - - ScopedFileDescriptor file_descriptor{open(file_path, O_RDONLY)}; - if (!file_descriptor.IsValid()) { - return ReportSystemError("open", file_path); - } - - if (length == 0) - return Status::OK(); - - if (offset > 0) { - const FileOffsetType seek_result = lseek(file_descriptor.Get(), offset, SEEK_SET); - if (seek_result == -1) { - return ReportSystemError("lseek", file_path); - } - } - - size_t total_bytes_read = 0; - while (total_bytes_read < length) { - constexpr size_t k_max_bytes_to_read = 1 << 30; // read at most 1GB each time - const size_t bytes_remaining = length - total_bytes_read; - const size_t bytes_to_read = std::min(bytes_remaining, k_max_bytes_to_read); - - const ssize_t bytes_read = - TempFailureRetry(read, file_descriptor.Get(), buffer.data() + total_bytes_read, bytes_to_read); - - if (bytes_read == -1) { - return ReportSystemError("read", file_path); - } - - if (bytes_read == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFileIntoBuffer - unexpected end of file. ", "File: ", file_path, - ", offset: ", offset, ", length: ", length); - } - - total_bytes_read += bytes_read; - } - - return Status::OK(); - } - + static common::Status ReportSystemError(const char* operation_name, const std::string& path) { auto [err_no, err_msg] = GetErrnoInfo(); @@ -420,69 +298,6 @@ class PosixEnv : public Env { return common::Status(common::SYSTEM, err_no, oss.str()); } - common::Status GetCanonicalPath( - const PathString& path, - PathString& canonical_path) const override { - MallocdStringPtr canonical_path_cstr{realpath(path.c_str(), nullptr), Freer()}; - if (!canonical_path_cstr) { - return ReportSystemError("realpath", path); - } - canonical_path.assign(canonical_path_cstr.get()); - return Status::OK(); - } - - common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const override { - dlerror(); // clear any old error_str - *handle = dlopen(library_filename.c_str(), RTLD_NOW | (global_symbols ? 
RTLD_GLOBAL : RTLD_LOCAL)); - char* error_str = dlerror(); - if (!*handle) { - return common::Status(common::ONNXRUNTIME, common::FAIL, - "Failed to load library " + library_filename + " with error: " + error_str); - } - return common::Status::OK(); - } - - common::Status UnloadDynamicLibrary(void* handle) const override { - if (!handle) { - return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null library handle"); - } - dlerror(); // clear any old error_str - int retval = dlclose(handle); - char* error_str = dlerror(); - if (retval != 0) { - return common::Status(common::ONNXRUNTIME, common::FAIL, - "Failed to unload library with error: " + std::string(error_str)); - } - return common::Status::OK(); - } - - common::Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override { - dlerror(); // clear any old error str - - // search global space if handle is nullptr. - // value of RTLD_DEFAULT differs across posix platforms (-2 on macos, 0 on linux). - handle = handle ? handle : RTLD_DEFAULT; - *symbol = dlsym(handle, symbol_name.c_str()); - - char* error_str = dlerror(); - if (error_str) { - return common::Status(common::ONNXRUNTIME, common::FAIL, - "Failed to get symbol " + symbol_name + " with error: " + error_str); - } - // it's possible to get a NULL symbol in our case when Schemas are not custom. - return common::Status::OK(); - } - - std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override { - std::string filename; - if (version.empty()) { - filename = "lib" + name + ".so"; - } else { - filename = "lib" + name + ".so" + "." + version; - } - return filename; - } - // \brief returns a value for the queried variable name (var_name) std::string GetEnvironmentVar(const std::string& var_name) const override { @@ -495,7 +310,7 @@ class PosixEnv : public Env { PosixEnv() { cpuinfo_available_ = cpuinfo_initialize(); if (!cpuinfo_available_) { - LOGS_DEFAULT(INFO) << "cpuinfo_initialize failed"; + std::cout << "cpuinfo_initialize failed"; } } bool cpuinfo_available_{false}; diff --git a/src/core/platform/windows/env.cc b/src/core/platform/windows/env.cc index f7b063f..9dab8a9 100644 --- a/src/core/platform/windows/env.cc +++ b/src/core/platform/windows/env.cc @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -27,14 +27,14 @@ limitations under the License. 
#include #include -#include -#include "core/session/onnxruntime_c_api.h" +#include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" #include +#include #include "core/platform/path_lib.h" // for LoopDir() @@ -42,615 +42,380 @@ EXTERN_C IMAGE_DOS_HEADER __ImageBase; namespace onnxruntime { - class UnmapFileParam { - public: - void* addr; - size_t len; - }; - - - - std::wstring Basename(const std::wstring& path) { - auto basename_index = path.find_last_of(L"/\\") + 1; // results in 0 if no separator is found - return path.substr(basename_index); - } - - class WindowsThread : public EnvThread { - private: - struct Param { - const ORTCHAR_T* name_prefix; - int index; - unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param); - Eigen::ThreadPoolInterface* param; - std::optional affinity; - Param(const ORTCHAR_T* name_prefix1, - int index1, - unsigned (*start_address1)(int id, Eigen::ThreadPoolInterface* param), - Eigen::ThreadPoolInterface* param1) - : name_prefix(name_prefix1), - index(index1), - start_address(start_address1), - param(param1) {} - }; - - public: - WindowsThread(const ORTCHAR_T* name_prefix, int index, - unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, - const ThreadOptions& thread_options) { - ORT_ENFORCE(index >= 0, "Negative thread index is not allowed"); - - - std::unique_ptr local_param = std::make_unique(name_prefix, index, start_address, param); - if (narrow(index) < thread_options.affinities.size()) { - local_param->affinity = thread_options.affinities[index]; - } - - { - _set_errno(0); - _set_doserrno(0); - auto th_handle = _beginthreadex(nullptr, thread_options.stack_size, ThreadMain, - local_param.get(), 0, - &threadID); - if (th_handle == 0) { - auto dos_error = _doserrno; - auto [err, msg] = GetErrnoInfo(); - ORT_THROW("WindowThread:_beginthreadex failed with errno:", err, " message:", msg, - " doserrno:", dos_error); - } - local_param.release(); - hThread.reset(reinterpret_cast(th_handle)); - // Do not throw beyond this point so we do not lose thread handle and then not being able to join it. 
- } - } - - ~WindowsThread() { - { - DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE); - assert(waitStatus != WAIT_FAILED); - } - } - - private: - typedef HRESULT(WINAPI* SetThreadDescriptionFunc)(HANDLE hThread, PCWSTR lpThreadDescription); + +std::wstring Basename(const std::wstring& path) { + auto basename_index = path.find_last_of(L"/\\") + 1; // results in 0 if no separator is found + return path.substr(basename_index); +} + +class WindowsThread : public EnvThread { + private: + struct Param { + const ORTCHAR_T* name_prefix; + int index; + unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param); + Eigen::ThreadPoolInterface* param; + std::optional affinity; + Param(const ORTCHAR_T* name_prefix1, + int index1, + unsigned (*start_address1)(int id, Eigen::ThreadPoolInterface* param), + Eigen::ThreadPoolInterface* param1) + : name_prefix(name_prefix1), + index(index1), + start_address(start_address1), + param(param1) {} + }; + + public: + WindowsThread(const ORTCHAR_T* name_prefix, int index, + unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, + const ThreadOptions& thread_options) { + ORT_ENFORCE(index >= 0, "Negative thread index is not allowed"); + + std::unique_ptr local_param = std::make_unique(name_prefix, index, start_address, param); + if (narrow(index) < thread_options.affinities.size()) { + local_param->affinity = thread_options.affinities[index]; + } + + { + _set_errno(0); + _set_doserrno(0); + auto th_handle = _beginthreadex(nullptr, thread_options.stack_size, ThreadMain, + local_param.get(), 0, + &threadID); + if (th_handle == 0) { + auto dos_error = _doserrno; + auto [err, msg] = GetErrnoInfo(); + ORT_THROW("WindowThread:_beginthreadex failed with errno:", err, " message:", msg, + " doserrno:", dos_error); + } + local_param.release(); + hThread.reset(reinterpret_cast(th_handle)); + // Do not throw beyond this point so we do not lose thread handle and then not being able to join it. + } + } + + ~WindowsThread() { + { + DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE); + FAIL_FAST_LAST_ERROR_IF(waitStatus == WAIT_FAILED); + } + } + + private: + typedef HRESULT(WINAPI* SetThreadDescriptionFunc)(HANDLE hThread, PCWSTR lpThreadDescription); #pragma warning(push) #pragma warning(disable : 6387) - static unsigned __stdcall ThreadMain(void* param) { - std::unique_ptr p(static_cast(param)); - - // Not all machines have kernel32.dll and/or SetThreadDescription (e.g. Azure App Service sandbox) - // so we need to ensure it's available before calling. - HMODULE kernelModule = GetModuleHandle(TEXT("kernel32.dll")); - if (kernelModule != nullptr) { - auto setThreadDescriptionFn = (SetThreadDescriptionFunc)GetProcAddress(kernelModule, "SetThreadDescription"); - if (setThreadDescriptionFn != nullptr) { - const ORTCHAR_T* name_prefix = (p->name_prefix == nullptr || wcslen(p->name_prefix) == 0) ? L"onnxruntime" - : p->name_prefix; - std::wostringstream oss; - oss << name_prefix << "-" << p->index; - // Ignore any errors - (void)(setThreadDescriptionFn)(GetCurrentThread(), oss.str().c_str()); - } - } - - unsigned ret = 0; - ORT_TRY{ - - -ret = p->start_address(p->index, p->param); - } - ORT_CATCH(...) { - p->param->Cancel(); - ret = 1; - } - return ret; - } + static unsigned __stdcall ThreadMain(void* param) { + std::unique_ptr p(static_cast(param)); + + // Not all machines have kernel32.dll and/or SetThreadDescription (e.g. 
Azure App Service sandbox) + // so we need to ensure it's available before calling. + HMODULE kernelModule = GetModuleHandle(TEXT("kernel32.dll")); + if (kernelModule != nullptr) { + auto setThreadDescriptionFn = (SetThreadDescriptionFunc)GetProcAddress(kernelModule, "SetThreadDescription"); + if (setThreadDescriptionFn != nullptr) { + const ORTCHAR_T* name_prefix = (p->name_prefix == nullptr || wcslen(p->name_prefix) == 0) ? L"onnxruntime" + : p->name_prefix; + std::wostringstream oss; + oss << name_prefix << "-" << p->index; + // Ignore any errors + (void)(setThreadDescriptionFn)(GetCurrentThread(), oss.str().c_str()); + } + } + + unsigned ret = 0; + ORT_TRY { + if (p->affinity.has_value() && !p->affinity->empty()) { + int group_id = -1; + KAFFINITY mask = 0; + constexpr KAFFINITY bit = 1; + const WindowsEnv& env = WindowsEnv::Instance(); + for (auto global_processor_id : *p->affinity) { + auto processor_info = env.GetProcessorAffinityMask(global_processor_id); + if (processor_info.local_processor_id > -1 && + processor_info.local_processor_id < sizeof(KAFFINITY) * CHAR_BIT) { + mask |= bit << processor_info.local_processor_id; + } else { + // Logical processor id starts from 0 internally, but in ort API, it starts from 1, + // that's why id need to increase by 1 when logging. + std::cout << "Cannot set affinity for thread " << GetCurrentThreadId() + << ", processor " << global_processor_id + 1 << " does not exist"; + group_id = -1; + mask = 0; + break; + } + if (group_id == -1) { + group_id = processor_info.group_id; + } else if (group_id != processor_info.group_id) { + std::cout << "Cannot set cross-group affinity for thread " + << GetCurrentThreadId() << ", first on group " + << group_id << ", then on " << processor_info.group_id; + group_id = -1; + mask = 0; + break; + } + } // for + if (group_id > -1 && mask) { + GROUP_AFFINITY thread_affinity = {}; + thread_affinity.Group = static_cast(group_id); + thread_affinity.Mask = mask; + if (SetThreadGroupAffinity(GetCurrentThread(), &thread_affinity, nullptr)) { + std::cout << "SetThreadAffinityMask done for thread: " << GetCurrentThreadId() + << ", group_id: " << thread_affinity.Group + << ", mask: " << thread_affinity.Mask; + } else { + const auto error_code = GetLastError(); + std::cout << "SetThreadAffinityMask failed for thread: " << GetCurrentThreadId() + << ", index: " << p->index + << ", mask: " << *p->affinity + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code) + << ". Specify the number of threads explicitly so the affinity is not set."; + } + } + } + + ret = p->start_address(p->index, p->param); + } + ORT_CATCH(...) { + p->param->Cancel(); + ret = 1; + } + return ret; + } #pragma warning(pop) - static void CustomThreadMain(void* param) { - std::unique_ptr p(static_cast(param)); - ORT_TRY{ - p->start_address(p->index, p->param); - } - ORT_CATCH(...) { - p->param->Cancel(); - } - } - unsigned threadID = 0; - wil::unique_handle hThread; - }; + static void CustomThreadMain(void* param) { + std::unique_ptr p(static_cast(param)); + ORT_TRY { + p->start_address(p->index, p->param); + } + ORT_CATCH(...) 
{ + p->param->Cancel(); + } + } + unsigned threadID = 0; + wil::unique_handle hThread; +}; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26409) #endif - EnvThread* WindowsEnv::CreateThread(_In_opt_z_ const ORTCHAR_T* name_prefix, int index, - unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), - Eigen::ThreadPoolInterface* param, const ThreadOptions& thread_options) { - return new WindowsThread(name_prefix, index, start_address, param, thread_options); - } +EnvThread* WindowsEnv::CreateThread(_In_opt_z_ const ORTCHAR_T* name_prefix, int index, + unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), + Eigen::ThreadPoolInterface* param, const ThreadOptions& thread_options) { + return new WindowsThread(name_prefix, index, start_address, param, thread_options); +} #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif - Env& Env::Default() { - return WindowsEnv::Instance(); - } - - void WindowsEnv::SleepForMicroseconds(int64_t micros) const { - Sleep(static_cast(micros) / 1000); - } - - - int WindowsEnv::DefaultNumCores() { - return std::max(1, static_cast(std::thread::hardware_concurrency() / 2)); - } - - int WindowsEnv::GetNumPhysicalCpuCores() const { - return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); - - } - - std::vector WindowsEnv::GetDefaultThreadAffinities() const { - return cores_.empty() ? std::vector(DefaultNumCores(), LogicalProcessors{}) : cores_; - } - - int WindowsEnv::GetL2CacheSize() const { - return l2_cache_size_; - } - - WindowsEnv& WindowsEnv::Instance() { - static WindowsEnv default_env; - return default_env; - } - - PIDType WindowsEnv::GetSelfPid() const { - return GetCurrentProcessId(); - } - - Status WindowsEnv::GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const { - wil::unique_hfile file_handle{ - CreateFile2(file_path, FILE_READ_ATTRIBUTES, FILE_SHARE_READ, OPEN_EXISTING, NULL) }; - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - LARGE_INTEGER filesize; - if (!GetFileSizeEx(file_handle.get(), &filesize)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileSizeEx ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - if (static_cast(filesize.QuadPart) > std::numeric_limits::max()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileLength: File is too large"); - } - length = static_cast(filesize.QuadPart); - return Status::OK(); - } - - common::Status WindowsEnv::GetFileLength(int fd, /*out*/ size_t& file_size) const { - using namespace common; - if (fd < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid fd was supplied: ", fd); - } - - struct _stat buf; - int rc = _fstat(fd, &buf); - if (rc < 0) { - return Status(SYSTEM, errno); - } - - if (buf.st_size < 0) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "Received negative size from stat call"); - } - - if (static_cast(buf.st_size) > std::numeric_limits::max()) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "File is too large."); - } - - file_size = static_cast(buf.st_size); - return Status::OK(); - } - - Status WindowsEnv::ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t 
length, - const gsl::span buffer) const { - ORT_RETURN_IF_NOT(file_path, "file_path == nullptr"); - ORT_RETURN_IF_NOT(offset >= 0, "offset < 0"); - ORT_RETURN_IF_NOT(length <= buffer.size(), "length > buffer.size()"); - wil::unique_hfile file_handle{ - CreateFile2(file_path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL) }; - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - - if (length == 0) - return Status::OK(); - - if (offset > 0) { - LARGE_INTEGER current_position; - current_position.QuadPart = offset; - if (!SetFilePointerEx(file_handle.get(), current_position, ¤t_position, FILE_BEGIN)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SetFilePointerEx ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - } - - size_t total_bytes_read = 0; - while (total_bytes_read < length) { - constexpr DWORD k_max_bytes_to_read = 1 << 30; // read at most 1GB each time - const size_t bytes_remaining = length - total_bytes_read; - const DWORD bytes_to_read = static_cast(std::min(bytes_remaining, k_max_bytes_to_read)); - DWORD bytes_read; - - if (!ReadFile(file_handle.get(), buffer.data() + total_bytes_read, bytes_to_read, &bytes_read, nullptr)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFile ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - - if (bytes_read != bytes_to_read) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFile ", ToUTF8String(Basename(file_path)), " fail: unexpected end"); - } - - total_bytes_read += bytes_read; - } - - return Status::OK(); - } - - - - common::Status WindowsEnv::GetCanonicalPath( - const PathString& path, - PathString& canonical_path) const { - // adapted from MSVC STL std::filesystem::canonical() implementation - // https://github.com/microsoft/STL/blob/ed3cbf36416a385828e7a5987ca52cb42882d84b/stl/inc/filesystem#L2986 - CREATEFILE2_EXTENDED_PARAMETERS param; - memset(¶m, 0, sizeof(param)); - param.dwSize = sizeof(CREATEFILE2_EXTENDED_PARAMETERS); - param.dwFileFlags = FILE_FLAG_BACKUP_SEMANTICS; - wil::unique_hfile file_handle{ CreateFile2( - path.c_str(), - FILE_READ_ATTRIBUTES, - FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - OPEN_EXISTING, - ¶m) }; - - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(path)), " fail, errcode = ", - error_code, " - ", std::system_category().message(error_code)); - } - - constexpr DWORD initial_buffer_size = MAX_PATH; - std::vector result_buffer{}; - result_buffer.resize(initial_buffer_size); - - while (true) { - const DWORD result_length = GetFinalPathNameByHandleW( - file_handle.get(), - result_buffer.data(), - static_cast(result_buffer.size()), - 0); - - ORT_RETURN_IF_NOT( - result_length > 0, "GetFinalPathNameByHandle() failed: ", GetLastError()); - - if (result_length < result_buffer.size()) { // buffer is large enough - canonical_path.assign(result_buffer.data(), result_length); - break; - } - - // need larger buffer - result_buffer.resize(result_length); - } - - // update prefixes - if (canonical_path.find(ORT_TSTR(R"(\\?\)")) 
== 0) { - if (canonical_path.size() > 6 && - (ORT_TSTR('A') <= canonical_path[4] && canonical_path[4] <= ORT_TSTR('Z') || - ORT_TSTR('a') <= canonical_path[4] && canonical_path[4] <= ORT_TSTR('z')) && - canonical_path[5] == ORT_TSTR(':')) { - // "\\?\:" -> ":" - canonical_path.erase(0, 4); - } - else if (canonical_path.find(ORT_TSTR(R"(UNC\)"), 4) == 4) { - // "\\?\UNC\" -> "\\" - canonical_path.erase(2, 6); - } - } - - return Status::OK(); - } - - // Return the path of the executable/shared library for the current running code. This is to make it - // possible to load other shared libraries installed next to our core runtime code. - PathString WindowsEnv::GetRuntimePath() const { - wchar_t buffer[MAX_PATH]; - if (!GetModuleFileNameW(reinterpret_cast(&__ImageBase), buffer, _countof(buffer))) { - return PathString(); - } - - // Remove the filename at the end, but keep the trailing slash - PathString path(buffer); - auto slash_index = path.find_last_of(ORT_TSTR('\\')); - if (slash_index == std::string::npos) { - // Windows supports forward slashes - slash_index = path.find_last_of(ORT_TSTR('/')); - if (slash_index == std::string::npos) { - return PathString(); - } - } - return path.substr(0, slash_index + 1); - } - - Status WindowsEnv::LoadDynamicLibrary(const PathString& wlibrary_filename, bool /*global_symbols*/, void** handle) const { -#if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP - * handle = ::LoadPackagedLibrary(wlibrary_filename.c_str(), 0); -#else - // TODO: in most cases, the path name is a relative path and the behavior of the following line of code is undefined. - * handle = ::LoadLibraryExW(wlibrary_filename.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH); +Env& Env::Default() { + return WindowsEnv::Instance(); +} + + +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. +#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) +static constexpr std::array kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" #endif - if (!*handle) { - const auto error_code = GetLastError(); - static constexpr DWORD bufferLength = 64 * 1024; - std::wstring s(bufferLength, '\0'); - FormatMessageW( - FORMAT_MESSAGE_FROM_SYSTEM | - FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - error_code, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPWSTR)s.data(), - 0, NULL); - std::wostringstream oss; - oss << L"LoadLibrary failed with error " << error_code << L" \"" << s.c_str() << L"\" when trying to load \"" << wlibrary_filename << L"\""; - std::wstring errmsg = oss.str(); - // TODO: trim the ending '\r' and/or '\n' - common::Status status(common::ONNXRUNTIME, common::FAIL, ToUTF8String(errmsg)); - return status; - } - return Status::OK(); - } - - Status WindowsEnv::UnloadDynamicLibrary(void* handle) const { - if (::FreeLibrary(reinterpret_cast(handle)) == 0) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "FreeLibrary failed with error ", error_code, " - ", std::system_category().message(error_code)); - } - return Status::OK(); - } - - namespace dlfcn_win32 { - // adapted from https://github.com/dlfcn-win32 version 1.3.1. - // Simplified to only support finding symbols in libraries that were linked against. - // If ORT dynamically loads a custom ops library using RegisterCustomOpsLibrary[_V2] the handle from the library load - // is explicitly provided in the call to GetSymbolFromLibrary. 
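For context on the helpers being removed here: the dlfcn_win32 code emulates dlsym(RTLD_DEFAULT, name) by enumerating every module loaded into the current process and probing each one with GetProcAddress. The following is a minimal sketch of that idiom only, not part of the patch; it links the Psapi wrapper directly for brevity, whereas the removed helper resolves K32EnumProcessModules from Kernel32.dll at runtime to avoid the extra link dependency, and the function name below is invented for illustration.

```cpp
// Hypothetical illustration; mirrors the removed dlfcn_win32::SearchModulesForSymbol.
#include <windows.h>
#include <psapi.h>
#include <vector>

void* FindSymbolInAnyModule(const char* name) {
  HANDLE process = GetCurrentProcess();
  DWORD bytes_needed = 0;
  // First call reports how many bytes the module-handle array requires.
  if (!EnumProcessModules(process, nullptr, 0, &bytes_needed)) {
    return nullptr;
  }
  std::vector<HMODULE> modules(bytes_needed / sizeof(HMODULE));
  if (!EnumProcessModules(process, modules.data(), bytes_needed, &bytes_needed)) {
    return nullptr;
  }
  // Probe every loaded module; the first hit wins, like dlsym(RTLD_DEFAULT, name).
  for (HMODULE module : modules) {
    if (FARPROC symbol = GetProcAddress(module, name)) {
      return reinterpret_cast<void*>(symbol);
    }
  }
  return nullptr;
}
```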
- // - /* Load Psapi.dll at runtime, this avoids linking caveat */ - bool MyEnumProcessModules(HANDLE hProcess, HMODULE* lphModule, DWORD cb, LPDWORD lpcbNeeded) { - using EnumProcessModulesFn = BOOL(WINAPI*)(HANDLE, HMODULE*, DWORD, LPDWORD); - static EnumProcessModulesFn EnumProcessModulesPtr = []() { - EnumProcessModulesFn fn = nullptr; - // Windows 7 and newer versions have K32EnumProcessModules in Kernel32.dll which is always pre-loaded - HMODULE psapi = GetModuleHandleA("Kernel32.dll"); - if (psapi) { - fn = (EnumProcessModulesFn)(LPVOID)GetProcAddress(psapi, "K32EnumProcessModules"); - } - - return fn; - }(); - - if (EnumProcessModulesPtr == nullptr) { - return false; - } - - return EnumProcessModulesPtr(hProcess, lphModule, cb, lpcbNeeded); - } - - void* SearchModulesForSymbol(const char* name) { - HANDLE current_proc = GetCurrentProcess(); - DWORD size = 0; - void* symbol = nullptr; - - // GetModuleHandle(NULL) only returns the current program file. So if we want to get ALL loaded module including - // those in linked DLLs, we have to use EnumProcessModules(). - if (MyEnumProcessModules(current_proc, nullptr, 0, &size) != false) { - size_t num_handles = size / sizeof(HMODULE); - std::unique_ptr modules = std::make_unique(num_handles); - HMODULE* modules_ptr = modules.get(); - DWORD cb_needed = 0; - if (MyEnumProcessModules(current_proc, modules_ptr, size, &cb_needed) != 0 && size == cb_needed) { - for (size_t i = 0; i < num_handles; i++) { - symbol = GetProcAddress(modules[i], name); - if (symbol != nullptr) { - break; - } - } - } - } - - return symbol; - } - } // namespace dlfcn_win32 - - Status WindowsEnv::GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const { - Status status = Status::OK(); - - // global search to replicate dlsym RTLD_DEFAULT if handle is nullptr - if (handle == nullptr) { - *symbol = dlfcn_win32::SearchModulesForSymbol(symbol_name.c_str()); - } - else { - *symbol = ::GetProcAddress(reinterpret_cast(handle), symbol_name.c_str()); - } - - if (!*symbol) { - const auto error_code = GetLastError(); - static constexpr DWORD bufferLength = 64 * 1024; - std::wstring s(bufferLength, '\0'); - FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, error_code, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPWSTR)s.data(), 0, NULL); - std::wostringstream oss; - oss << L"Failed to find symbol " << ToWideString(symbol_name) << L" in library, error code: " - << error_code << L" \"" << s.c_str() << L"\""; - std::wstring errmsg = oss.str(); - // TODO: trim the ending '\r' and/or '\n' - status = Status(common::ONNXRUNTIME, common::FAIL, ToUTF8String(errmsg)); - } - - return status; - } - - std::string WindowsEnv::FormatLibraryFileName(const std::string& name, const std::string& version) const { - ORT_UNUSED_PARAMETER(name); - ORT_UNUSED_PARAMETER(version); - ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); - } - - - - // \brief returns a value for the queried variable name (var_name) - std::string WindowsEnv::GetEnvironmentVar(const std::string& var_name) const { - // Why getenv() should be avoided on Windows: - // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv - // Instead use the Win32 API: GetEnvironmentVariableA() - - // Max limit of an environment variable on Windows including the null-terminating character - constexpr DWORD kBufferSize = 32767; - - // Create buffer to hold the result - std::string buffer(kBufferSize, '\0'); - - // The last argument is the size of the 
buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters. - // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character. - // Therefore, If the function succeeds, kBufferSize should be larger than char_count. - auto char_count = GetEnvironmentVariableA(var_name.c_str(), buffer.data(), kBufferSize); - - if (kBufferSize > char_count) { - buffer.resize(char_count); - return buffer; - } - - // Else either the call was failed, or the buffer wasn't large enough. - // TODO: Understand the reason for failure by calling GetLastError(). - // If it is due to the specified environment variable being found in the environment block, - // GetLastError() returns ERROR_ENVVAR_NOT_FOUND. - // For now, we assume that the environment variable is not found. - - return std::string(); - } - - /* - Read logical processor info from the map. - {-1,-1} stands for failure. - */ - ProcessorInfo WindowsEnv::GetProcessorAffinityMask(int global_processor_id) const { - if (global_processor_info_map_.count(global_processor_id)) { - return global_processor_info_map_.at(global_processor_id); - } - else { - return { -1, -1 }; - } - } - - WindowsEnv::WindowsEnv() { - l2_cache_size_ = 0; - InitializeCpuInfo(); - } - - /* - Discover all cores in a windows system. - Note - every "id" here, given it be group id, core id, or logical processor id, starts from 0. - */ - void WindowsEnv::InitializeCpuInfo() { - DWORD returnLength = 0; - GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &returnLength); - auto last_error = GetLastError(); - if (last_error != ERROR_INSUFFICIENT_BUFFER) { - return; - } - - std::unique_ptr allocation = std::make_unique(returnLength); - SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX* processorInfos = reinterpret_cast(allocation.get()); - - if (!GetLogicalProcessorInformationEx(RelationProcessorCore, processorInfos, &returnLength)) { - return; - } - - int core_id = 0; - int global_processor_id = 0; - const BYTE* iter = reinterpret_cast(processorInfos); - const BYTE* end = iter + returnLength; - std::stringstream log_stream; - - while (iter < end) { - auto processor_info = reinterpret_cast(iter); - auto size = processor_info->Size; - - // Discoverred a phyical core and it belongs exclusively to a single group - if (processor_info->Relationship == RelationProcessorCore && - processor_info->Processor.GroupCount == 1) { - log_stream << std::endl - << "core " << core_id + 1 << " consist of logical processors: "; - LogicalProcessors core_global_proc_ids; - constexpr KAFFINITY bit = 1; - constexpr int id_upper_bound = sizeof(KAFFINITY) * CHAR_BIT; - const auto& group_mask = processor_info->Processor.GroupMask[0]; - for (int logical_proessor_id = 0; logical_proessor_id < id_upper_bound; ++logical_proessor_id) { - if (group_mask.Mask & (bit << logical_proessor_id)) { - log_stream << global_processor_id + 1 << " "; - core_global_proc_ids.push_back(global_processor_id); - /* - * Build up a map between global processor id and local processor id. - * The map helps to bridge between ort API and windows affinity API - - * we need local processor id to build an affinity mask for a particular group. 
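Both the old and the reindented InitializeCpuInfo rely on the standard two-call pattern for GetLogicalProcessorInformationEx: the first call is made with a null buffer so it fails with ERROR_INSUFFICIENT_BUFFER and reports the byte count, then the buffer is allocated and the second call fills it. A minimal sketch of just that pattern follows; it is not part of the patch, and the helper name is invented for illustration.

```cpp
// Hypothetical illustration of the size-query-then-fetch idiom used by InitializeCpuInfo.
#include <windows.h>
#include <memory>

std::unique_ptr<char[]> QueryProcessorCores(DWORD& length) {
  length = 0;
  // Expected to fail with ERROR_INSUFFICIENT_BUFFER and set `length` to the required size.
  GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &length);
  if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
    return nullptr;
  }
  auto buffer = std::make_unique<char[]>(length);
  auto* infos = reinterpret_cast<SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*>(buffer.get());
  // Second call fills the buffer with variable-length records.
  if (!GetLogicalProcessorInformationEx(RelationProcessorCore, infos, &length)) {
    return nullptr;
  }
  // Callers walk the records by advancing with each record's Size field.
  return buffer;
}
```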
- */ - global_processor_info_map_.insert_or_assign(global_processor_id, - ProcessorInfo{ static_cast(group_mask.Group), - logical_proessor_id }); - global_processor_id++; - } - } - cores_.push_back(std::move(core_global_proc_ids)); - core_id++; - } - iter += size; - } - - DWORD newLength = 0; - GetLogicalProcessorInformationEx(RelationCache, nullptr, &newLength); - last_error = GetLastError(); - if (last_error != ERROR_INSUFFICIENT_BUFFER) { - return; - } - - if (newLength > returnLength) { - // Re-allocate - allocation = std::make_unique(newLength); - processorInfos = reinterpret_cast(allocation.get()); - } - - if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) { - return; - } - - iter = reinterpret_cast(processorInfos); - end = iter + newLength; - - while (iter < end) { - auto processor_info = reinterpret_cast(iter); - auto size = processor_info->Size; - - if (processor_info->Relationship == RelationCache && - processor_info->Cache.Level == 2) { - // L2 cache - l2_cache_size_ = static_cast(processor_info->Cache.CacheSize); - break; - } - - iter += size; - } - - } +int WindowsEnv::DefaultNumCores() { + return std::max(1, static_cast(std::thread::hardware_concurrency() / 2)); +} + +int WindowsEnv::GetNumPhysicalCpuCores() const { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); +} + +std::vector WindowsEnv::GetDefaultThreadAffinities() const { + return cores_.empty() ? std::vector(DefaultNumCores(), LogicalProcessors{}) : cores_; +} + +int WindowsEnv::GetL2CacheSize() const { + return l2_cache_size_; +} + +WindowsEnv& WindowsEnv::Instance() { + static WindowsEnv default_env; + return default_env; +} + +PIDType WindowsEnv::GetSelfPid() const { + return GetCurrentProcessId(); +} + +// \brief returns a value for the queried variable name (var_name) +std::string WindowsEnv::GetEnvironmentVar(const std::string& var_name) const { + // Why getenv() should be avoided on Windows: + // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv + // Instead use the Win32 API: GetEnvironmentVariableA() + + // Max limit of an environment variable on Windows including the null-terminating character + constexpr DWORD kBufferSize = 32767; + + // Create buffer to hold the result + std::string buffer(kBufferSize, '\0'); + + // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters. + // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character. + // Therefore, If the function succeeds, kBufferSize should be larger than char_count. + auto char_count = GetEnvironmentVariableA(var_name.c_str(), buffer.data(), kBufferSize); + + if (kBufferSize > char_count) { + buffer.resize(char_count); + return buffer; + } + + // Else either the call was failed, or the buffer wasn't large enough. + // TODO: Understand the reason for failure by calling GetLastError(). + // If it is due to the specified environment variable being found in the environment block, + // GetLastError() returns ERROR_ENVVAR_NOT_FOUND. + // For now, we assume that the environment variable is not found. + + return std::string(); +} + +/* +Read logical processor info from the map. +{-1,-1} stands for failure. 
+*/ +ProcessorInfo WindowsEnv::GetProcessorAffinityMask(int global_processor_id) const { + if (global_processor_info_map_.count(global_processor_id)) { + return global_processor_info_map_.at(global_processor_id); + } else { + return {-1, -1}; + } +} + +WindowsEnv::WindowsEnv() { + l2_cache_size_ = 0; + InitializeCpuInfo(); +} + +/* +Discover all cores in a windows system. +Note - every "id" here, given it be group id, core id, or logical processor id, starts from 0. +*/ +void WindowsEnv::InitializeCpuInfo() { + DWORD returnLength = 0; + GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &returnLength); + auto last_error = GetLastError(); + if (last_error != ERROR_INSUFFICIENT_BUFFER) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + std::cout << "Failed to calculate byte size for saving cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + std::unique_ptr allocation = std::make_unique(returnLength); + SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX* processorInfos = reinterpret_cast(allocation.get()); + + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, processorInfos, &returnLength)) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + std::cout << "Failed to fetch cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + int core_id = 0; + int global_processor_id = 0; + const BYTE* iter = reinterpret_cast(processorInfos); + const BYTE* end = iter + returnLength; + std::stringstream log_stream; + + while (iter < end) { + auto processor_info = reinterpret_cast(iter); + auto size = processor_info->Size; + + // Discoverred a phyical core and it belongs exclusively to a single group + if (processor_info->Relationship == RelationProcessorCore && + processor_info->Processor.GroupCount == 1) { + log_stream << std::endl + << "core " << core_id + 1 << " consist of logical processors: "; + LogicalProcessors core_global_proc_ids; + constexpr KAFFINITY bit = 1; + constexpr int id_upper_bound = sizeof(KAFFINITY) * CHAR_BIT; + const auto& group_mask = processor_info->Processor.GroupMask[0]; + for (int logical_proessor_id = 0; logical_proessor_id < id_upper_bound; ++logical_proessor_id) { + if (group_mask.Mask & (bit << logical_proessor_id)) { + log_stream << global_processor_id + 1 << " "; + core_global_proc_ids.push_back(global_processor_id); + /* + * Build up a map between global processor id and local processor id. + * The map helps to bridge between ort API and windows affinity API - + * we need local processor id to build an affinity mask for a particular group. 
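The map being populated here pairs each global processor id with its processor group and its zero-based index within that group; the thread entry point earlier in this file turns such a pair into a GROUP_AFFINITY mask for SetThreadGroupAffinity. A minimal sketch of that consumption, with an invented helper name, purely for illustration and not part of the patch:

```cpp
// Hypothetical illustration of how one (group_id, local_processor_id) entry
// from global_processor_info_map_ becomes a thread affinity.
#include <windows.h>

bool PinCurrentThread(int group_id, int local_processor_id) {
  GROUP_AFFINITY affinity = {};
  affinity.Group = static_cast<WORD>(group_id);
  // The mask has one bit per logical processor within the group, which is why
  // the group-local processor id (not the global one) is needed here.
  affinity.Mask = KAFFINITY{1} << local_processor_id;
  return SetThreadGroupAffinity(GetCurrentThread(), &affinity, nullptr) != 0;
}
```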
+ */ + global_processor_info_map_.insert_or_assign(global_processor_id, + ProcessorInfo{static_cast(group_mask.Group), + logical_proessor_id}); + global_processor_id++; + } + } + cores_.push_back(std::move(core_global_proc_ids)); + core_id++; + } + iter += size; + } + + DWORD newLength = 0; + GetLogicalProcessorInformationEx(RelationCache, nullptr, &newLength); + last_error = GetLastError(); + if (last_error != ERROR_INSUFFICIENT_BUFFER) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + std::cout << "Failed to calculate byte size for saving cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + if (newLength > returnLength) { + // Re-allocate + allocation = std::make_unique(newLength); + processorInfos = reinterpret_cast(allocation.get()); + } + + if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + std::cout << "Failed to fetch cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + iter = reinterpret_cast(processorInfos); + end = iter + newLength; + + while (iter < end) { + auto processor_info = reinterpret_cast(iter); + auto size = processor_info->Size; + + if (processor_info->Relationship == RelationCache && + processor_info->Cache.Level == 2) { + // L2 cache + l2_cache_size_ = static_cast(processor_info->Cache.CacheSize); + break; + } + + iter += size; + } + + if (logging::LoggingManager::HasDefaultLogger()) { + std::cout << "Found total " << cores_.size() << " core(s) from windows system:"; + std::cout << log_stream.str(); + std::cout << "\nDetected L2 cache size: " << l2_cache_size_ << " bytes"; + } +} } // namespace onnxruntime diff --git a/src/core/platform/windows/env.h b/src/core/platform/windows/env.h index 9e53a71..e66ca6f 100644 --- a/src/core/platform/windows/env.h +++ b/src/core/platform/windows/env.h @@ -50,25 +50,14 @@ class WindowsEnv : public Env { #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif - void SleepForMicroseconds(int64_t micros) const override; static int DefaultNumCores(); int GetNumPhysicalCpuCores() const override; std::vector GetDefaultThreadAffinities() const override; int GetL2CacheSize() const override; static WindowsEnv& Instance(); PIDType GetSelfPid() const override; - Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const override; - common::Status GetFileLength(int fd, /*out*/ size_t& file_size) const override; - Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, - const gsl::span buffer) const override; - - - common::Status GetCanonicalPath(const PathString& path, PathString& canonical_path) const override; - PathString GetRuntimePath() const override; - Status LoadDynamicLibrary(const PathString& library_filename, bool /*global_symbols*/, void** handle) const override; - Status UnloadDynamicLibrary(void* handle) const override; - Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override; - std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override; + + std::string GetEnvironmentVar(const std::string& var_name) const override; ProcessorInfo GetProcessorAffinityMask(int global_processor_id) 
const; diff --git a/src/core/platform/windows/stacktrace.cc b/src/core/platform/windows/stacktrace.cc index 3401507..cc23d70 100644 --- a/src/core/platform/windows/stacktrace.cc +++ b/src/core/platform/windows/stacktrace.cc @@ -30,7 +30,6 @@ class CaptureStackTrace { // Get the stack trace. Currently only enabled for a DEBUG build as we require the DbgHelp library. std::vector GetStackTrace() { #ifndef NDEBUG -// TVM need to run with shared CRT, so won't work with debug helper now #if (defined __cpp_lib_stacktrace) && !(defined _OPSCHEMA_LIB_) && !(defined _GAMING_XBOX) && !(defined ONNXRUNTIME_ENABLE_MEMLEAK_CHECK) return detail::CaptureStackTrace().Trace(); #else diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index fbc7037..da54462 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -1,795 +1,933 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -set(MLAS_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) -set(MLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(MLAS_INC_DIR ${MLAS_ROOT}/../include) - -include_directories(${ONNXRUNTIME_INCLUDE_DIR}) - -#Set global compile flags for all the source code(including third_party code like protobuf) -#This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch -if (MSVC) - if (CMAKE_VS_PLATFORM_NAME) - # Multi-platform generator - set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) - else() - set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) - endif() - if (onnxruntime_target_platform STREQUAL "ARM64") - set(onnxruntime_target_platform "ARM64") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM64EC") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") - set(onnxruntime_target_platform "ARM") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") - set(onnxruntime_target_platform "x64") - enable_language(ASM_MASM) - elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") - set(onnxruntime_target_platform "x86") - enable_language(ASM_MASM) - message("Enabling SAFESEH for x86 build") - set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") - else() - message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") - endif() -endif() - -# -# All hardware agnostic source files here -# hardware specific files would cause trouble in -# multi-target build -# -add_library(onnxruntime_mlas STATIC - ${MLAS_SRC_DIR}/mlasi.h - ${MLAS_SRC_DIR}/platform.cpp - ${MLAS_SRC_DIR}/threading.cpp - ${MLAS_SRC_DIR}/sgemm.cpp - ${MLAS_SRC_DIR}/halfgemm.cpp - ${MLAS_SRC_DIR}/qgemm.cpp - ${MLAS_SRC_DIR}/qdwconv.cpp - ${MLAS_SRC_DIR}/convolve.cpp - ${MLAS_SRC_DIR}/convsym.cpp - ${MLAS_SRC_DIR}/pooling.cpp - ${MLAS_SRC_DIR}/transpose.cpp - ${MLAS_SRC_DIR}/reorder.cpp - ${MLAS_SRC_DIR}/snchwc.cpp - ${MLAS_SRC_DIR}/activate.cpp - ${MLAS_SRC_DIR}/logistic.cpp - ${MLAS_SRC_DIR}/tanh.cpp - ${MLAS_SRC_DIR}/erf.cpp - ${MLAS_SRC_DIR}/compute.cpp - ${MLAS_SRC_DIR}/quantize.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp - ${MLAS_SRC_DIR}/qladd.cpp - ${MLAS_SRC_DIR}/qlmul.cpp - ${MLAS_SRC_DIR}/qpostprocessor.cpp - ${MLAS_SRC_DIR}/qlgavgpool.cpp - 
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/sqnbitgemm.h - ${MLAS_SRC_DIR}/sqnbitgemm.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h - ${MLAS_SRC_DIR}/flashattn.cpp - ${MLAS_SRC_DIR}/cast.cpp -) - -target_sources(onnxruntime_mlas PRIVATE - ${MLAS_INC_DIR}/mlas_float16.h - ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h - ${MLAS_INC_DIR}/mlas_q4.h - ${MLAS_INC_DIR}/mlas_qnbit.h - ${MLAS_INC_DIR}/mlas.h -) - -if (NOT onnxruntime_ORT_MINIMAL_BUILD) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4_dq.cpp - ${MLAS_SRC_DIR}/q4gemm.cpp - ) -endif() - - -#TODO: set MASM flags properly -function(setup_mlas_source_for_windows) - - # - # Sources common for all platforms. - # - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ) - - #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake - #Don't use it for other platforms. - if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) - set(PREPROCESS_ARMASM_FLAGS "") - set(ARMASM_FLAGS "") - - if(onnxruntime_target_platform STREQUAL "ARM64") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm - ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm - ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm - ) - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm - ) - - string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") - string(APPEND ARMASM_FLAGS " -machine ARM64EC") - endif() - - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - string(APPEND ARMASM_FLAGS " -g") - endif() - - # Remove double quotes from flag strings. - separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") - separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") - - # Run the C precompiler on each input before the assembler. 
- foreach(asm_filename ${mlas_platform_preprocess_srcs}) - get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) - set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) - set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) - add_custom_command( - OUTPUT ${obj_filename} - COMMAND - cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} - COMMAND - armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} - DEPENDS ${asm_filename} - BYPRODUCTS ${preprocess_filename} - ) - target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) - endforeach() - elseif(onnxruntime_target_platform STREQUAL "ARM") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ) - elseif(onnxruntime_target_platform STREQUAL "x64") - - file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") - - file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") - - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/dgemm.cpp - ${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/sgemma.asm - ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm - ) - if(MSVC_VERSION GREATER_EQUAL 1933) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm - ) - endif() - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - 
target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - endif() - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm - ) - endif() -endfunction() - -if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/wasm_simd/*.cpp" - ) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp - ) - else() - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp" - ) - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -elseif(MSVC) - setup_mlas_source_for_windows() -else() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - - if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) - set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) - endif() - foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) - if (OSX_ARCH STREQUAL "arm64") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm64e") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm") - set(ARM TRUE) - elseif (OSX_ARCH STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (OSX_ARCH STREQUAL "i386") - set(X86 TRUE) - endif() - endforeach() - elseif(ANDROID) - if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") - set(ARM TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") - set(ARM64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") - set(X86 TRUE) - endif() - else() - #Linux/FreeBSD/PowerPC/... - #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` - #Example values: - #arm64v8/ubuntu -> aarch64 - #arm32v6/alpine -> armv7l - #arm32v7/centos -> armv7l - #ppc64le/debian -> ppc64le - #s390x/ubuntu -> s390x - #ppc64le/busybox -> ppc64le - #arm64v8/ubuntu -> aarch64 - #Android: armv7-a aarch64 i686 x86_64 - #chasun: I don't think anyone uses 'arm64' - if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") - set(ARM TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") - set(POWER TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - set(X86 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - set(X86_64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") - set(LOONGARCH64 TRUE) - endif() - endif() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) - set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) - endif() - #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below - #and split MLAS to multiple static libraries. - #Otherwise, it works like if(...) elseif(...) elseif(...) 
endif() - set(MLAS_SOURCE_IS_NOT_SET 1) - if(ARM) - enable_language(ASM) - - set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) - enable_language(ASM) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") - if (NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S - ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp - ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_arm64 STATIC 
${mlas_platform_srcs}) - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) - set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(POWER AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp - ${MLAS_SRC_DIR}/power/QuantizePower.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") - - check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9) - if (HAS_POWER9) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9") - endif() - - check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10) - if(HAS_POWER10) - set(CMAKE_REQUIRED_FLAGS "-mcpu=power10") - check_cxx_source_compiles(" - #include - int main() { - __vector_quad acc0; - __builtin_mma_xxsetaccz (&acc0); - return 0; - }" - COMPILES_P10 - ) - if(COMPILES_P10) - check_cxx_source_compiles(" - #ifdef _AIX - #define POWER_10 0x40000 - #define POWER_10_ANDUP (POWER_10) - #include - #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) - int main() { - bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); - return 0; - } - #else - #include - int main() { - unsigned long hwcap2 = getauxval(AT_HWCAP2); - bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); - return 0; - } - } - #endif" - HAS_P10_RUNTIME - ) - if (HAS_P10_RUNTIME) - set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - endif() - set(mlas_platform_srcs_power10 - ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE") - set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10") - set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10") - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${mlas_platform_srcs_power10} - ) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S - ) - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs - ${mlas_platform_srcs_sse2} - ${mlas_platform_srcs_avx} - ) - - # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own - # implementation to avoid external dependency. 
- if(ANDROID) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S - ) - endif() - - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - # Forward the flags for the minimum target platform version from the C - # compiler to the assembler. This works around CMakeASMCompiler.cmake.in - # not including the logic to set this flag for the assembler. - set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") - - # The LLVM assembler does not support the .arch directive to enable instruction - # set extensions and also doesn't support AVX-512F instructions without - # turning on support via command-line option. Group the sources by the - # instruction set extension and explicitly set the compiler flag as appropriate. - - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S - ) - if(NOT APPLE) - set(mlas_platform_srcs_sse2 - ${mlas_platform_srcs_sse2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S - ) - endif() - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S - ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs_avx2 - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S - ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp - ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ) - if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) - set(mlas_platform_srcs_avx2 - ${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S - ) - endif() -message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") -message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") - -if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") - message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") -else() - message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") -endif() - set(mlas_platform_srcs_avx512f - 
${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") - - set(mlas_platform_srcs_avx512core - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") - - set(mlas_platform_srcs_avx512vnni - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${mlas_platform_srcs_sse2} - ${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - ${mlas_platform_srcs_avx512f} - ${mlas_platform_srcs_avx512core} - ${mlas_platform_srcs_avx512vnni} - ) - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - if(NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S - ) - set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_x86_64 STATIC ${mlas_platform_srcs}) - set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S - ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S - ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx") - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) - file(GLOB_RECURSE 
mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp") - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -endif() - -foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - target_link_libraries(${mlas_target} Microsoft.GSL::GSL) - - set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") -endforeach() - -if (WIN32) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") - if (onnxruntime_ENABLE_STATIC_ANALYSIS) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") - endif() -endif() - -if (PLATFORM_NAME STREQUAL "macabi") - # Needed for maccatalyst C compilation - # i.e. the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" - target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) -endif() - -if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_mlas - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() - -# set up source group for MLAS source files -block() - set(source_group_srcs) - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - get_target_property(mlas_target_srcs ${mlas_target} SOURCES) - foreach(mlas_target_src ${mlas_target_srcs}) - cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) - if(in_mlas_root) - list(APPEND source_group_srcs ${mlas_target_src}) - endif() - endforeach() - endforeach() -endblock() - - - - # - # Command line tool for quantization and de-quantization of 2-D fp32 tensors - # based on block-wise quantization of int4 - # - - add_executable(onnxruntime_mlas_q4dq - ${MLAS_SRC_DIR}/q4_dq_cli.cpp - ) - target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") - - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS}) - if(NOT MLAS_NO_ONNXRUNTIME) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) - endif() - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) - endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) - endif() - - if(WIN32) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE Threads::Threads) - - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") - endif() - endif() - +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +set(MLAS_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) 
+set(MLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(MLAS_INC_DIR ${MLAS_ROOT}/../include) + +include_directories(${ONNXRUNTIME_INCLUDE_DIR}) + +#Set global compile flags for all the source code(including third_party code like protobuf) +#This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch +if (MSVC) + if (CMAKE_VS_PLATFORM_NAME) + # Multi-platform generator + set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) + else() + set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + endif() + if (onnxruntime_target_platform STREQUAL "ARM64") + set(onnxruntime_target_platform "ARM64") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM64EC") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") + set(onnxruntime_target_platform "ARM") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") + set(onnxruntime_target_platform "x64") + enable_language(ASM_MASM) + elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") + set(onnxruntime_target_platform "x86") + enable_language(ASM_MASM) + message("Enabling SAFESEH for x86 build") + set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") + else() + message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + endif() +endif() + + +# mlas_private_compile_definitions contains compile definitions that are private to onnxruntime_mlas and targets which +# use internal MLAS headers like mlasi.h. 
+set(mlas_private_compile_definitions) +# +# All hardware agnostic source files here +# hardware specific files would cause trouble in +# multi-target build +# +onnxruntime_add_static_library(onnxruntime_mlas + ${MLAS_SRC_DIR}/mlasi.h + ${MLAS_SRC_DIR}/platform.cpp + ${MLAS_SRC_DIR}/threading.cpp + ${MLAS_SRC_DIR}/sgemm.cpp + ${MLAS_SRC_DIR}/halfgemm.cpp + ${MLAS_SRC_DIR}/qgemm.cpp + ${MLAS_SRC_DIR}/qdwconv.cpp + ${MLAS_SRC_DIR}/convolve.cpp + ${MLAS_SRC_DIR}/convsym.cpp + ${MLAS_SRC_DIR}/pooling.cpp + ${MLAS_SRC_DIR}/transpose.cpp + ${MLAS_SRC_DIR}/reorder.cpp + ${MLAS_SRC_DIR}/snchwc.cpp + ${MLAS_SRC_DIR}/activate.cpp + ${MLAS_SRC_DIR}/logistic.cpp + ${MLAS_SRC_DIR}/tanh.cpp + ${MLAS_SRC_DIR}/eltwise.h + ${MLAS_SRC_DIR}/eltwise.cpp + ${MLAS_SRC_DIR}/erf.cpp + ${MLAS_SRC_DIR}/compute.cpp + ${MLAS_SRC_DIR}/dequantize.cpp + ${MLAS_SRC_DIR}/quantize.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp + ${MLAS_SRC_DIR}/qladd.cpp + ${MLAS_SRC_DIR}/qlmul.cpp + ${MLAS_SRC_DIR}/qpostprocessor.cpp + ${MLAS_SRC_DIR}/qlgavgpool.cpp + ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/qnbitgemm.h + ${MLAS_SRC_DIR}/qnbitgemm.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h + ${MLAS_SRC_DIR}/flashattn.cpp + ${MLAS_SRC_DIR}/cast.cpp + ${MLAS_SRC_DIR}/rotary_embedding.h + ${MLAS_SRC_DIR}/rotary_embedding.cpp + ${MLAS_SRC_DIR}/softmax.h + ${MLAS_SRC_DIR}/saturation_check.cpp +) + +target_sources(onnxruntime_mlas PRIVATE + ${MLAS_INC_DIR}/mlas_float16.h + ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h + ${MLAS_INC_DIR}/mlas_q4.h + ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas.h +) + +if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/q4_dq.cpp + ${MLAS_SRC_DIR}/q4gemm.cpp + ) +endif() + +set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) + +#TODO: set MASM flags properly +function(setup_mlas_source_for_windows) + + # + # Sources common for all platforms. + # + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ) + + #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake + #Don't use it for other platforms. 
+ if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) + set(PREPROCESS_ARMASM_FLAGS "") + set(ARMASM_FLAGS "") + + if(onnxruntime_target_platform STREQUAL "ARM64") + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/softmax_kernel_neon.h + ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp + ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon.h + ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + ) + + set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm + ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm + ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm + ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm + ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm + ) + + if (onnxruntime_USE_ARM_NEON_NCHWC) + setup_arm_neon_nchwc() + endif() + + if (onnxruntime_USE_KLEIDIAI) + setup_kleidiai() + endif() + else() + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ) + + set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm + ) + + string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") + string(APPEND ARMASM_FLAGS " -machine ARM64EC") + endif() + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + string(APPEND ARMASM_FLAGS " -g") + endif() + + # Remove double quotes from flag strings. + separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") + separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") + + # Run the C precompiler on each input before the assembler. 
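+    # For each kernel this expands to roughly the following two commands (file names illustrative):
+    #   cl.exe ${PREPROCESS_ARMASM_FLAGS} /P SgemmKernelNeon.asm /FiSgemmKernelNeon.i
+    #   armasm64.exe ${ARMASM_FLAGS} SgemmKernelNeon.i SgemmKernelNeon.obj
+    # so the .asm sources can use C preprocessor includes and definitions before being assembled.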
+ foreach(asm_filename ${mlas_platform_preprocess_srcs}) + get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) + set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) + set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) + add_custom_command( + OUTPUT ${obj_filename} + COMMAND + cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} + COMMAND + armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} + DEPENDS ${asm_filename} + BYPRODUCTS ${preprocess_filename} + ) + target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) + endforeach() + elseif(onnxruntime_target_platform STREQUAL "ARM") + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/arm/sgemmc.cpp + ) + elseif(onnxruntime_target_platform STREQUAL "x64") + + file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") + + file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" + ) + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") + + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/dgemm.cpp + ${mlas_platform_srcs_avx} + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm + ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/sgemma.asm + ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm + ) + + 
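+    # Note that only the intrinsics/avx and intrinsics/avx2 sources above get an explicit /arch switch;
+    # a hypothetical new intrinsics file requiring AVX-512 would follow the same pattern, e.g.
+    #   set_source_files_properties(${MLAS_SRC_DIR}/intrinsics/avx512/my_new_kernel.cpp
+    #                               PROPERTIES COMPILE_FLAGS "/arch:AVX512")
+    # while the .asm kernels listed above carry their instruction-set requirements in the assembly itself.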
if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + set_source_files_properties(${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm PROPERTIES COMPILE_FLAGS "-DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER") + endif() + + if(MSVC_VERSION GREATER_EQUAL 1933) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm + ) + endif() + + if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + endif() + else() + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp + ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm + ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm + ) + endif() +endfunction() + +function(setup_kleidiai) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp + ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp + ) + target_link_libraries(onnxruntime_mlas PRIVATE kleidiai) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai) + set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES} PARENT_SCOPE) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS kleidiai EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() +endfunction() + +function (setup_arm_neon_nchwc) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/sconv.h + ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp + ${MLAS_SRC_DIR}/spool_kernel_neon.cpp + ) + list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC) + set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE) +endfunction () + +if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/wasm_simd/*.cpp" + ) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp + ) + if (onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/qgemm_kernel_wasmrelaxedsimd.cpp + ) + endif() + else() + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/scalar/*.cpp" + ) + endif() + target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) +elseif(MSVC) + setup_mlas_source_for_windows() +else() + if(APPLE) + get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) + + if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) + set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + endif() + foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) + if (OSX_ARCH STREQUAL "arm64") + set(ARM64 TRUE) + elseif (OSX_ARCH STREQUAL "arm64e") + set(ARM64 TRUE) + elseif (OSX_ARCH STREQUAL "arm") + set(ARM TRUE) + elseif (OSX_ARCH STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (OSX_ARCH STREQUAL "i386") + set(X86 TRUE) + endif() + endforeach() + elseif(ANDROID) + if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") + set(ARM TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") + set(ARM64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") + set(X86 TRUE) + endif() + else() + #Linux/FreeBSD/PowerPC/... 
+ #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` + #Example values: + #arm64v8/ubuntu -> aarch64 + #arm32v6/alpine -> armv7l + #arm32v7/centos -> armv7l + #ppc64le/debian -> ppc64le + #s390x/ubuntu -> s390x + #ppc64le/busybox -> ppc64le + #arm64v8/ubuntu -> aarch64 + #Android: armv7-a aarch64 i686 x86_64 + #chasun: I don't think anyone uses 'arm64' + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") + set(ARM TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") + set(POWER TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + set(X86 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + set(X86_64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") + set(LOONGARCH64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^s390x$") + set(S390X TRUE) + endif() + endif() + + if(APPLE) + get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) + endif() + list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) + if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) + set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) + endif() + #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below + #and split MLAS to multiple static libraries. + #Otherwise, it works like if(...) elseif(...) elseif(...) endif() + set(MLAS_SOURCE_IS_NOT_SET 1) + if(ARM) + enable_language(ASM) + + set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S + ${MLAS_SRC_DIR}/arm/sgemmc.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ) + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) + enable_language(ASM) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S + ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp + ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/softmax_kernel_neon.h + ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon.h + ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + ) + + # Conditionally add the SVE implementation if compiler 
supports it + if (onnxruntime_USE_SVE) + list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h) + list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp) + set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ") + list(APPEND mlas_private_compile_definitions MLAS_USE_SVE) + endif() + + if (onnxruntime_USE_ARM_NEON_NCHWC) + setup_arm_neon_nchwc() + endif() + + if (onnxruntime_USE_KLEIDIAI) + setup_kleidiai() + endif() + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + + if (NOT APPLE) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + endif() + + if(ONNXRUNTIME_MLAS_MULTI_ARCH) + onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) + set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") + 
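+      # In a multi-architecture build (e.g. a macOS universal binary) the arm64 kernels are compiled into
+      # this separate static library and collected via ONNXRUNTIME_MLAS_LIBS; in a single-architecture
+      # build the else() branch leaves them in mlas_platform_srcs to be added to onnxruntime_mlas below.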
list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) + set(mlas_platform_srcs ) + else() + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(POWER AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp + ${MLAS_SRC_DIR}/dgemm.cpp + ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp + ${MLAS_SRC_DIR}/power/QuantizePower.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") + + check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9) + if (HAS_POWER9) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9") + endif() + + check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10) + if(HAS_POWER10) + set(CMAKE_REQUIRED_FLAGS "-mcpu=power10") + check_cxx_source_compiles(" + #include + int main() { + __vector_quad acc0; + __builtin_mma_xxsetaccz (&acc0); + return 0; + }" + COMPILES_P10 + ) + if(COMPILES_P10) + check_cxx_source_compiles(" + #ifdef _AIX + #define POWER_10 0x40000 + #define POWER_10_ANDUP (POWER_10) + #include + #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) + int main() { + bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); + return 0; + } + #else + #include + int main() { + unsigned long hwcap2 = getauxval(AT_HWCAP2); + bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); + return 0; + } + #endif" + HAS_P10_RUNTIME + ) + if (HAS_P10_RUNTIME) + set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") + set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") + endif() + set(mlas_platform_srcs_power10 + ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp + ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp + ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE") + set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10") + set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_power10} + ) + endif() + endif() + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(X86 AND MLAS_SOURCE_IS_NOT_SET) + enable_language(ASM) + + set(mlas_platform_srcs_sse2 + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S + ) + set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") + + set(mlas_platform_srcs_avx + ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") + + set(mlas_platform_srcs + ${mlas_platform_srcs_sse2} + ${mlas_platform_srcs_avx} + ) + + # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own + # implementation to avoid external dependency. 
+ if(ANDROID) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S + ) + endif() + + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) + enable_language(ASM) + + # Forward the flags for the minimum target platform version from the C + # compiler to the assembler. This works around CMakeASMCompiler.cmake.in + # not including the logic to set this flag for the assembler. + set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") + + # The LLVM assembler does not support the .arch directive to enable instruction + # set extensions and also doesn't support AVX-512F instructions without + # turning on support via command-line option. Group the sources by the + # instruction set extension and explicitly set the compiler flag as appropriate. + + set(mlas_platform_srcs_sse2 + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S + ) + if(NOT APPLE) + set(mlas_platform_srcs_sse2 + ${mlas_platform_srcs_sse2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S + ) + endif() + set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") + + set(mlas_platform_srcs_avx + ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S + ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") + + set(mlas_platform_srcs_avx2 + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S + ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp + ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp + ${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ) + if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) + set(mlas_platform_srcs_avx2 + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S + ) + endif() + +message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") + +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") +else() + 
message(STATUS "Using -mavx2 -mfma flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") +endif() + set(mlas_platform_srcs_avx512f + ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") + + set(mlas_platform_srcs_avx512core + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S + ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") + + set(mlas_platform_srcs_avx512vnni + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/dgemm.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${mlas_platform_srcs_sse2} + ${mlas_platform_srcs_avx} + ${mlas_platform_srcs_avx2} + ${mlas_platform_srcs_avx512f} + ${mlas_platform_srcs_avx512core} + ${mlas_platform_srcs_avx512vnni} + ) + + if (NOT onnxruntime_ORT_MINIMAL_BUILD) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() + if(NOT APPLE) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S + ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S + ) + set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() + + if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + set_source_files_properties(${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER") + endif() + + if(ONNXRUNTIME_MLAS_MULTI_ARCH) + onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) + set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") + list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) + set(mlas_platform_srcs ) + else() + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_lasx.cpp + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S + 
${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S + ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S + ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx") + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(S390X AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/s390x/SgemmKernel.cpp + ${MLAS_SRC_DIR}/s390x/SgemmKernelZVECTOR.cpp + ${MLAS_SRC_DIR}/dgemm.cpp + ${MLAS_SRC_DIR}/s390x/DgemmKernel.cpp + ${MLAS_SRC_DIR}/s390x/Quantize.cpp + ${MLAS_SRC_DIR}/s390x/QuantizeZVECTOR.cpp + ${MLAS_SRC_DIR}/s390x/qgemm_kernel_zvector.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/s390x/SgemmKernel.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") + set_source_files_properties(${MLAS_SRC_DIR}/s390x/SgemmKernelZVECTOR.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mvx -mzvector -march=z15") + + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/scalar/*.cpp") + elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS) + file(GLOB_RECURSE mlas_platform_srcs_generic + "${MLAS_SRC_DIR}/scalar/*.cpp") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_generic} + ) + endif() + target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) +endif() + +foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + target_link_libraries(${mlas_target} Microsoft.GSL::GSL) + + target_compile_definitions(${mlas_target} PRIVATE ${mlas_private_compile_definitions}) + + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") +endforeach() + +if (WIN32) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") + if (onnxruntime_ENABLE_STATIC_ANALYSIS) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") + endif() +endif() + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_mlas EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() + + + + + + + # + # Command line tool for quantization and de-quantization of 2-D fp32 tensors + # based on block-wise quantization of int4 + # + + onnxruntime_add_executable(onnxruntime_mlas_q4dq + ${MLAS_SRC_DIR}/q4_dq_cli.cpp + ) + target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") + + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS}) + if(NOT MLAS_NO_ONNXRUNTIME) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) + endif() + if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) + endif() + if(NOT WIN32) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${CMAKE_DL_LIBS}) + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) + endif() + + if(WIN32) + 
target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) + endif() + if (onnxruntime_LINK_LIBATOMIC) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) + endif() + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE Threads::Threads) + + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + endif() + endif() + + diff --git a/src/lib/activate.cpp b/src/lib/activate.cpp index df3b884..5dd8244 100644 --- a/src/lib/activate.cpp +++ b/src/lib/activate.cpp @@ -141,7 +141,7 @@ struct MLAS_ACTIVATION_FUNCTION return _mm_blendv_ps(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value)); #elif defined(MLAS_SSE2_INTRINSICS) return MlasBlendFloat32x4(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value)); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value)); #elif defined(MLAS_LSX_INTRINSICS) return MlasBlendFloat32x4(ValueTimesAlpha, Value, (__m128)__lsx_vfcmp_cle_s(ZeroFloat32x4, Value)); diff --git a/src/lib/activate_fp16.cpp b/src/lib/activate_fp16.cpp index 776ec67..564c6a4 100644 --- a/src/lib/activate_fp16.cpp +++ b/src/lib/activate_fp16.cpp @@ -51,12 +51,12 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - return MlasMaximumFloat16x8(ZeroVec, Value); + return MlasMaximumFloat16(ZeroVec, Value); } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - return MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(ZeroVec), Value); + return MlasMaximumFloat16(MlasToLowHalfFloat16x4(ZeroVec), Value); } }; @@ -75,7 +75,7 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16x8(Value, AlphaBroadcast); + MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16(Value, AlphaBroadcast); return MlasBitwiseSelectFloat16x8(MlasCmpLessEqualFloat16x8(Value, ZeroVec), ValueTimesAlpha, Value); } @@ -83,7 +83,7 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { MLAS_FLOAT16X4 ValueTimesAlpha = - MlasMultiplyFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast)); + MlasMultiplyFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast)); return MlasBitwiseSelectFloat16x4( MlasCmpLessEqualFloat16x4(Value, MlasToLowHalfFloat16x4(ZeroVec)), ValueTimesAlpha, Value); @@ -539,16 +539,16 @@ struct MLAS_HALF_ACTIVATION_FUNCTION { MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - Value = MlasMaximumFloat16x8(MinimumBroadcast, Value); - Value = MlasMinimumFloat16x8(MaximumBroadcast, Value); + Value = MlasMaximumFloat16(MinimumBroadcast, Value); + Value = MlasMinimumFloat16(MaximumBroadcast, Value); return Value; } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); - Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); + Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); + Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); return Value; } }; @@ -573,19 +573,19 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - Value = MlasMultiplyAddFloat16x8(Value, 
AlphaBroadcast, BetaBroadcast); - Value = MlasMinimumFloat16x8(MaximumBroadcast, Value); - Value = MlasMaximumFloat16x8(MinimumBroadcast, Value); + Value = MlasMultiplyAddFloat16(Value, AlphaBroadcast, BetaBroadcast); + Value = MlasMinimumFloat16(MaximumBroadcast, Value); + Value = MlasMaximumFloat16(MinimumBroadcast, Value); return Value; } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - Value = MlasMultiplyAddFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast), + Value = MlasMultiplyAddFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast), MlasToLowHalfFloat16x4(BetaBroadcast)); - Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); - Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); + Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); + Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); return Value; } @@ -692,7 +692,7 @@ MlasActivationKernel( MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc); MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer); addsrc += 8; - Vector = MlasAddFloat16x8(Vector, AVec); + Vector = MlasAddFloat16(Vector, AVec); Vector = ActivationFunction.Activate(Vector); MlasStoreFloat16x8(buffer, Vector); buffer += 8; @@ -703,7 +703,7 @@ MlasActivationKernel( MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc); MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer); addsrc += 4; - Vector = MlasAddFloat16x4(Vector, AVec); + Vector = MlasAddFloat16(Vector, AVec); Vector = ActivationFunction.Activate(Vector); MlasStoreFloat16x4(buffer, Vector); buffer += 4; @@ -715,7 +715,7 @@ MlasActivationKernel( MLAS_FLOAT16X4 buf; std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_)); std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_)); - buf = MlasAddFloat16x4(buf, addbuf); + buf = MlasAddFloat16(buf, addbuf); buf = ActivationFunction.Activate(buf); MlasStorePartialFloat16x4(buffer, buf, n); } diff --git a/src/lib/amd64/ConvSymKernelAvx2.asm b/src/lib/amd64/ConvSymKernelAvx2.asm index a42d7ff..9c334be 100644 --- a/src/lib/amd64/ConvSymKernelAvx2.asm +++ b/src/lib/amd64/ConvSymKernelAvx2.asm @@ -23,6 +23,87 @@ INCLUDE ConvSymKernelCommon.inc INCLUDE AssembleAvxVnni.inc .list +extern CheckSaturationForVPMADDUBSW:proc + +CheckSaturation MACRO VecReg1Num, VecReg2Num + +; +; Save all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11). no RSI, RDI. +; + + push_reg rax + push_reg rcx + push_reg rdx + push_reg r8 + push_reg r9 + push_reg r10 + push_reg r11 + + sub rsp, 512 ; reserve space for 16 YMM registers (32 bytes) + +; +; Save YMM registers (YMM0 to YMM15). +; + + vmovdqu YMMWORD PTR [rsp], ymm0 + vmovdqu YMMWORD PTR [rsp+32], ymm1 + vmovdqu YMMWORD PTR [rsp+64], ymm2 + vmovdqu YMMWORD PTR [rsp+96], ymm3 + vmovdqu YMMWORD PTR [rsp+128], ymm4 + vmovdqu YMMWORD PTR [rsp+160], ymm5 + vmovdqu YMMWORD PTR [rsp+192], ymm6 + vmovdqu YMMWORD PTR [rsp+224], ymm7 + vmovdqu YMMWORD PTR [rsp+256], ymm8 + vmovdqu YMMWORD PTR [rsp+288], ymm9 + vmovdqu YMMWORD PTR [rsp+320], ymm10 + vmovdqu YMMWORD PTR [rsp+352], ymm11 + vmovdqu YMMWORD PTR [rsp+384], ymm12 + vmovdqu YMMWORD PTR [rsp+416], ymm13 + vmovdqu YMMWORD PTR [rsp+448], ymm14 + vmovdqu YMMWORD PTR [rsp+480], ymm15 + + lea rcx, [rsp+32*VecReg1Num] ; first operand (unsigned) + lea rdx, [rsp+32*VecReg2Num] ; second operand (signed) + + call CheckSaturationForVPMADDUBSW + +; +; Restore YMM registers. 
+; + + vmovdqu ymm0, YMMWORD PTR [rsp] + vmovdqu ymm1, YMMWORD PTR [rsp+32] + vmovdqu ymm2, YMMWORD PTR [rsp+64] + vmovdqu ymm3, YMMWORD PTR [rsp+96] + vmovdqu ymm4, YMMWORD PTR [rsp+128] + vmovdqu ymm5, YMMWORD PTR [rsp+160] + vmovdqu ymm6, YMMWORD PTR [rsp+192] + vmovdqu ymm7, YMMWORD PTR [rsp+224] + vmovdqu ymm8, YMMWORD PTR [rsp+256] + vmovdqu ymm9, YMMWORD PTR [rsp+288] + vmovdqu ymm10, YMMWORD PTR [rsp+320] + vmovdqu ymm11, YMMWORD PTR [rsp+352] + vmovdqu ymm12, YMMWORD PTR [rsp+384] + vmovdqu ymm13, YMMWORD PTR [rsp+416] + vmovdqu ymm14, YMMWORD PTR [rsp+448] + vmovdqu ymm15, YMMWORD PTR [rsp+480] + + add rsp, 512 ; clean up the reserved stack space + +; +; Restore all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11), no RSI, RDI. +; + + pop r11 + pop r10 + pop r9 + pop r8 + pop rdx + pop rcx + pop rax + + ENDM + ; ; Macro Description: ; @@ -50,9 +131,15 @@ INCLUDE AssembleAvxVnni.inc MultiplyAccumulateRowAvx2 MACRO Vec1Reg, Vec2Reg +IFDEF ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER + CheckSaturation 2,0 +ENDIF vpmaddubsw ymm3,ymm2,ymm0 vpmaddwd ymm3,ymm3,ymm12 vpaddd Vec1Reg,Vec1Reg,ymm3 +IFDEF ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER + CheckSaturation 2,1 +ENDIF vpmaddubsw ymm2,ymm2,ymm1 vpmaddwd ymm2,ymm2,ymm12 vpaddd Vec2Reg,Vec2Reg,ymm2 diff --git a/src/lib/cast.cpp b/src/lib/cast.cpp index 9b5800b..6b71fe3 100644 --- a/src/lib/cast.cpp +++ b/src/lib/cast.cpp @@ -33,6 +33,60 @@ MlasConvertHalfToFloatBuffer( } } + +void +MLASCALL +MlasConvertHalfToFloatBufferInParallel( + const MLAS_FP16* Source, + float* Destination, + size_t Count, + MLAS_THREADPOOL* ThreadPool +) +{ +#if defined(BUILD_MLAS_NO_ONNXRUNTIME) + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + // If the ThreadPool is not available, use the single-threaded version. + MlasConvertHalfToFloatBuffer(Source, Destination, Count); +#else + // Check if the Tensor is long enough to use threads. + // Check if the Thread Pool is available. + // If not, execute single threaded conversion of half to float + if (!((Count > MLAS_MIN_TENSOR_SIZE_FOR_HALF_TO_FLOAT_CONVERSION_IN_PARALLEL) && ThreadPool)) { + MlasConvertHalfToFloatBuffer(Source, Destination, Count); + } + else { + + // Calculate the number of compute cycles per implementation + size_t num_compute_cycles; + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasSSE3()) { + num_compute_cycles = Count >> 1; + } else if (MLAS_CPUIDINFO::GetCPUIDInfo().HasAVX2()) { + num_compute_cycles = Count >> 2; + } else { + num_compute_cycles = Count * 10; + } + + MLAS_THREADPOOL::TryParallelFor( + ThreadPool, Count, + // Tensor Op Cost + { + static_cast(Count * sizeof(MLAS_FP16)), // Size of no. of elements in bytes to be loaded + static_cast(Count * sizeof(float)), // Size of no. of elements in bytes to be stored + static_cast(num_compute_cycles), // No. of compute cycles required for the tensor op + }, + // Lambda function required by TryParallelFor method + [Source, Destination](std::ptrdiff_t first_span, std::ptrdiff_t last_span) { + MlasConvertHalfToFloatBuffer( + Source + first_span, + Destination + first_span, + static_cast(last_span - first_span)); + } + ); + } +#endif // BUILD_MLAS_NO_ONNXRUNTIME +} + void MLASCALL MlasConvertFloatToHalfBuffer( diff --git a/src/lib/fp16_neon_common.cpp b/src/lib/cast_kernel_neon.cpp similarity index 99% rename from src/lib/fp16_neon_common.cpp rename to src/lib/cast_kernel_neon.cpp index 29734c2..8a385c9 100644 --- a/src/lib/fp16_neon_common.cpp +++ b/src/lib/cast_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. 
Module Name: - fp16_neon_common.cpp + cast_kernel_neon.cpp Abstract: diff --git a/src/lib/compute.cpp b/src/lib/compute.cpp index 73df23e..4916062 100644 --- a/src/lib/compute.cpp +++ b/src/lib/compute.cpp @@ -20,6 +20,7 @@ Module Name: --*/ #include "mlasi.h" +#include "softmax.h" // // Bundles the constants for use by kernels written in assembly. @@ -68,12 +69,14 @@ MLAS_INTERNAL_DATA const float MlasMinimumF32Value = std::numeric_limits: // threads. // +template struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; - const float* Input; - float* Output; + float Sink; + const T* Input; + T* Output; size_t N; size_t D; }; @@ -244,9 +247,10 @@ Return Value: } } +template <> void MLASCALL -MlasComputeExp( +MlasComputeExp( const float* Input, float* Output, size_t N @@ -280,6 +284,20 @@ Return Value: #endif } +template <> +void MLASCALL +MlasComputeExp( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +) { + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + if (dispatch == nullptr || dispatch->Exp_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Exp_Fp16 is not supported."); + } + dispatch->Exp_Fp16(Input, Output, N); +} + MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasComputeSumExpVector( @@ -483,7 +501,6 @@ Return Value: Input += 1; N -= 1; } - return Accumulator; } @@ -553,7 +570,6 @@ Return Value: Input += 1; N -= 1; } - return Maximum; } @@ -783,10 +799,18 @@ Return Value: } } +template void MlasComputeSoftmaxThreaded( void* Context, ptrdiff_t Index +); + +template <> +void +MlasComputeSoftmaxThreaded( + void* Context, + ptrdiff_t Index ) /*++ @@ -807,8 +831,8 @@ Return Value: --*/ { - const auto* WorkBlock = (MLAS_SOFTMAX_WORK_BLOCK*)Context; - + const auto* WorkBlock = (MLAS_SOFTMAX_WORK_BLOCK*)Context; + // // Partition the operation along the N dimension. // @@ -825,6 +849,7 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; + const float Sink = WorkBlock->Sink; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -849,30 +874,34 @@ Return Value: // // Find the maximum value for the row. // + float Maximum; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - float Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D); -#else - float Maximum = MlasReduceMaximumF32Kernel(Input, D); +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) + Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D); +#else + Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif - float NegativeMaximum = -Maximum; - if (SmoothSoftmax && NegativeMaximum > 0.0f) { - NegativeMaximum = 0.0f; + if (SmoothSoftmax && Sink > Maximum) { + Maximum = Sink; } + float NegativeMaximum = -Maximum; + // // Compute the exponential function for each element of the row (save to Temp if provided) and // compute the sum of these exponential functions. // float* Temp = LogSoftmax ? 
nullptr : Output; -#if defined(MLAS_TARGET_AMD64) - float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); + float Accumulation; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) + Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #else - float Accumulation = MlasComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); + Accumulation = MlasComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #endif if (SmoothSoftmax) { - Accumulation += expf(NegativeMaximum); + Accumulation += expf(Sink + NegativeMaximum); } if (LogSoftmax) { @@ -881,19 +910,19 @@ Return Value: // float Parameters[] = {NegativeMaximum, std::log(Accumulation)}; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); -#else +#else + MlasComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); #endif - } else { // // Normalize the softmax output. // float Parameters[] = {1.0f / Accumulation}; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters); #else MlasComputeSoftmaxOutputF32Kernel(Output, D, Parameters); @@ -906,15 +935,90 @@ Return Value: } } +template <> +void +MlasComputeSoftmaxThreaded( + void* Context, + ptrdiff_t Index +) +/*++ + +Routine Description: + + This routine is invoked from a worker thread to execute a segment of a + softmax or log softmax operation. + +Arguments: + + Context - Supplies the pointer to the context for the threaded operation. + + ThreadId - Supplies the current index of the threaded operation. + +Return Value: + + None. + +--*/ +{ + const auto* WorkBlock = (MLAS_SOFTMAX_WORK_BLOCK*)Context; + size_t n; + size_t CountN; + MlasPartitionWork(Index, WorkBlock->ThreadCountN, WorkBlock->N, &n, &CountN); + + const size_t D = WorkBlock->D; + const bool LogSoftmax = WorkBlock->LogSoftmax; + const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; + + const MLAS_FP16* Input = WorkBlock->Input + n * D; + MLAS_FP16* Output = WorkBlock->Output + n * D; + + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + if (dispatch == nullptr || + dispatch->ReduceMax_Fp16 == nullptr || + dispatch->SumExp_Fp16 == nullptr || + (LogSoftmax && dispatch->LogSoftmax_Fp16 == nullptr) || + (!LogSoftmax && dispatch->Softmax_Fp16 == nullptr)) { + MLAS_THROW_EX(std::runtime_error, "Lacks kernels for fp16 softmax."); + } + + while (CountN > 0) { + MLAS_FP16 Maximum = dispatch->ReduceMax_Fp16(Input, D); + MLAS_FP16 NegativeMaximum = Maximum.Negate(); + if (SmoothSoftmax && !NegativeMaximum.IsNegative()) { + NegativeMaximum = MLAS_FP16::FromBits(0); + } + + MLAS_FP16* Temp = LogSoftmax ? 
nullptr : Output; + MLAS_FP16 Accumulation = dispatch->SumExp_Fp16(Input, Temp, D, NegativeMaximum); + float accumulation_fp32 = Accumulation.ToFloat(); + + if (SmoothSoftmax) { + accumulation_fp32 += expf(NegativeMaximum.ToFloat()); + } + + if (LogSoftmax) { + dispatch->LogSoftmax_Fp16(Input, Output, D, NegativeMaximum, MLAS_FP16(std::log(accumulation_fp32))); + } else { + dispatch->Softmax_Fp16(Output, Output, D, MLAS_FP16(accumulation_fp32)); + } + + Input += D; + Output += D; + CountN--; + } +} + +template void MLASCALL MlasComputeSoftmax( - const float* Input, - float* Output, + const T* Input, + T* Output, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ) /*++ @@ -940,6 +1044,8 @@ Routine Description: SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. + Sink - Supplies the smooth factor to use in the softmax operation. + ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -949,7 +1055,7 @@ Return Value: --*/ { - MLAS_SOFTMAX_WORK_BLOCK WorkBlock; + MLAS_SOFTMAX_WORK_BLOCK WorkBlock; // // Capture the softmax parameters to the work block. @@ -961,6 +1067,7 @@ Return Value: WorkBlock.Output = Output; WorkBlock.N = N; WorkBlock.D = D; + WorkBlock.Sink = Sink; // // Compute the number of target threads given the complexity of the softmax @@ -985,5 +1092,69 @@ Return Value: WorkBlock.ThreadCountN = ThreadCountN; - MlasExecuteThreaded(MlasComputeSoftmaxThreaded, &WorkBlock, ThreadCountN, ThreadPool); + MlasExecuteThreaded(MlasComputeSoftmaxThreaded, &WorkBlock, ThreadCountN, ThreadPool); +} + +template +void +MLASCALL +MlasComputeSoftmax( + const float* Input, + float* Output, + size_t N, + size_t D, + bool LogSoftmax, + bool SmoothSoftmax, + float Sink, + MLAS_THREADPOOL* ThreadPool +); + +template +void +MLASCALL +MlasComputeSoftmax( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N, + size_t D, + bool LogSoftmax, + bool SmoothSoftmax, + float Sink, + MLAS_THREADPOOL* ThreadPool +); + +template <> +bool +MLASCALL +MlasGQASupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB +) { + if (!MlasHGemmSupported(TransA, TransB)) { + return false; + } + + const auto* softmax_dispatch = GetMlasPlatform().SoftmaxDispatch; + if (softmax_dispatch == nullptr || + softmax_dispatch->Tanh_Fp16 == nullptr || + softmax_dispatch->Softcap_Fp16 == nullptr || + softmax_dispatch->SumExp_Fp16 == nullptr || + softmax_dispatch->Softmax_Fp16 == nullptr || + softmax_dispatch->ReduceMax_Fp16 == nullptr) { + return false; + } + + return true; +} + +template <> +bool +MLASCALL +MlasGQASupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB +) { + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + return true; } diff --git a/src/lib/convolve.cpp b/src/lib/convolve.cpp index ec79641..9518134 100644 --- a/src/lib/convolve.cpp +++ b/src/lib/convolve.cpp @@ -729,6 +729,82 @@ Return Value: } } +void +MlasConvExpandThenGemmSegmentedThreaded( + void* Context, + ptrdiff_t Index +) +/*++ + +Routine Description: + + This routine is invoked from a worker thread to execute a segment of a + convolution operation. + + If using this, the entire convolution operation is parallelized on the + (batch size * group count) parameter and this routine has logic to + perform a specific thread's shard of the entire Convolution operation. + +Arguments: + + Context - Supplies the pointer to the context for the threaded operation. 
+ + Index - Supplies the current index of the threaded operation. + +Return Value: + + None. + +--*/ + +{ + MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context; + + const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters; + + const size_t GroupCount = Parameters->GroupCount; + const size_t BatchGroupCount = Parameters->BatchCount * GroupCount; + + const size_t TargetThreadCount = WorkBlock->TargetThreadCount; + + const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount; + const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount; + + size_t BatchGroupStart; + size_t BatchGroupEnd; + + if (static_cast(Index) < BatchGroupCountExtra) { + BatchGroupStart = (BatchGroupCountPerThread + 1) * Index; + BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1; + } else { + BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra; + BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread; + } + + const size_t FilterCount = Parameters->FilterCount; + const size_t OutputSize = Parameters->OutputSize; + const size_t K = Parameters->K; + + const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize; + const size_t OutputGroupSize = FilterCount * OutputSize; + const size_t FilterGroupSize = FilterCount * K; + + for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) { + size_t group = bg % GroupCount; + + const float* input = WorkBlock->Input + bg * InputGroupSize; + const float* filter = WorkBlock->Filter + group * FilterGroupSize; + float* output = WorkBlock->Output + bg * OutputGroupSize; + const float* bias = WorkBlock->Bias; + if (bias != nullptr) { + bias += group * FilterCount; + } + float* ColumnBuffer = WorkBlock->WorkingBuffer + Index * OutputSize * K; + + MlasConvOperation(Parameters, input, filter, bias, ColumnBuffer, output, 0, OutputSize); + } +} + inline bool MlasConvTryMultithread( @@ -861,6 +937,12 @@ Return Value: --*/ { + // Override + if(GetMlasPlatform().MlasConvOverride != nullptr && + GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){ + return; + } + const size_t FilterCount = Parameters->FilterCount; const size_t OutputSize = Parameters->OutputSize; const size_t K = Parameters->K; @@ -884,8 +966,8 @@ Return Value: ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); - if (size_t(TargetThreadCount) >= BatchGroupCount) { - TargetThreadCount = ptrdiff_t(BatchGroupCount); + if (static_cast(TargetThreadCount) >= BatchGroupCount) { + TargetThreadCount = static_cast(BatchGroupCount); } MLAS_CONV_WORK_BLOCK WorkBlock; @@ -913,6 +995,30 @@ Return Value: #endif + if (Algorithm == MlasConvAlgorithmExpandThenGemmSegmented && ((BatchCount > 1) || (GroupCount > 1))) { + const size_t BatchGroupCount = BatchCount * GroupCount; + + ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (static_cast(TargetThreadCount) >= BatchGroupCount) { + TargetThreadCount = static_cast(BatchGroupCount); + } + + MLAS_CONV_WORK_BLOCK WorkBlock; + + WorkBlock.Parameters = Parameters; + WorkBlock.Input = Input; + WorkBlock.Filter = Filter; + WorkBlock.Bias = Bias; + WorkBlock.WorkingBuffer = WorkingBuffer; + WorkBlock.Output = Output; + WorkBlock.TargetThreadCount = TargetThreadCount; + + MlasExecuteThreaded(MlasConvExpandThenGemmSegmentedThreaded, &WorkBlock, TargetThreadCount, ThreadPool); + + return; + } + // // Iterate over each batch and group. 
// @@ -1094,6 +1200,13 @@ Return Value: --*/ { + // Override + if (GetMlasPlatform().MlasConvPrepareOverride != nullptr && + GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels, + InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount, + Activation, WorkingBufferSize, Beta, ThreadPool)){ + return; + } // // Save the convolution parameters. // @@ -1295,8 +1408,20 @@ Return Value: Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN; *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; + + if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) { + + size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K, + Parameters->FilterCount * Parameters->OutputSize, + static_cast(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)}); + TargetThreadCount = MaximumThreadCount; + if (static_cast(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) { + TargetThreadCount = static_cast(Parameters->BatchCount * Parameters->GroupCount); + } + *WorkingBufferSize = TargetThreadCount * WorkingBufferSizePerThread; + } } } #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) -#endif \ No newline at end of file +#endif diff --git a/src/lib/convsym.cpp b/src/lib/convsym.cpp index 0ea7bef..5591aa4 100644 --- a/src/lib/convsym.cpp +++ b/src/lib/convsym.cpp @@ -16,7 +16,8 @@ Module Name: --*/ #include "mlasi.h" -#include +#include + // // Define the prototypes of the platform optimized routines. // diff --git a/src/lib/dequantize.cpp b/src/lib/dequantize.cpp new file mode 100644 index 0000000..175d3f6 --- /dev/null +++ b/src/lib/dequantize.cpp @@ -0,0 +1,395 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + dequantize.cpp + +Abstract: + + This module implements routines to dequantize buffers. + + The dequantization formula as specified in the ONNX operator documentation is: + + Output = (Input - ZeroPoint) * Scale + +--*/ + +#include "mlasi.h" + +// +// DequantizeLinear reference implementation using the C++ runtime. +// + +template +static +MLAS_FORCEINLINE +void +MlasDequantizeLinearRefImpl( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer with quantized data. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + Output[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } +} + +#if defined(MLAS_SSE2_INTRINSICS) +// Implementation for Intel SSE 2. Refer to the Intel Intrisics Guide: +// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 
15] + __m128i VectorS8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Sign-extend into 2 vectors of 8 int16s + __m128i SignMaskS8 = _mm_cmpgt_epi8(Zeros, VectorS8); // 0xFF for every negative byte in VectorS8 + __m128i VectorS16_0 = _mm_unpacklo_epi8(VectorS8, SignMaskS8); // [0 ... 7] + __m128i VectorS16_1 = _mm_unpackhi_epi8(VectorS8, SignMaskS8); // [8 ... 15] + + // Subtract the zero-points in int16 domain. + VectorS16_0 = _mm_sub_epi16(VectorS16_0, ZeroPointS16Vector); + VectorS16_1 = _mm_sub_epi16(VectorS16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + __m128i VectorU8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Zero-extend into 2 vectors of 8 uint16s + __m128i VectorU16_0 = _mm_unpacklo_epi8(VectorU8, Zeros); // [0 ... 7] + __m128i VectorU16_1 = _mm_unpackhi_epi8(VectorU8, Zeros); // [8 ... 15] + + // Subtract the zero-points as uint16s. Due to two's compliment, negative results can be reinterpreted as int16 + __m128i VectorS16_0 = _mm_sub_epi16(VectorU16_0, ZeroPointS16Vector); + __m128i VectorS16_1 = _mm_sub_epi16(VectorU16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. 
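The unpack-with-sign-mask sequence above (_mm_cmpgt_epi8 followed by _mm_unpacklo/hi_epi8) is how SSE2 sign-extends int8 lanes without the SSE4.1 _mm_cvtepi8_epi16 instruction. A scalar model of a single lane, for illustration only:

#include <cstdint>

// Pairing each byte with 0xFF (if negative) or 0x00 (if non-negative) in the high
// half reproduces ordinary two's-complement sign extension to 16 bits.
int16_t WidenS8ToS16(int8_t value)
{
    const uint8_t highByte = (value < 0) ? 0xFFu : 0x00u;           // _mm_cmpgt_epi8(0, v), per lane
    const uint16_t widened = static_cast<uint16_t>(static_cast<uint8_t>(value)) |
                             static_cast<uint16_t>(highByte << 8);  // _mm_unpacklo_epi8, per lane
    return static_cast<int16_t>(widened);                           // same bits as int16_t(value)
}

// WidenS8ToS16(-3) == -3 (0xFFFD); WidenS8ToS16(5) == 5 (0x0005).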
+ __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearS8Kernel( +#else + MlasDequantizeLinearS8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearU8Kernel( +#else + MlasDequantizeLinearU8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} +#elif defined(MLAS_NEON64_INTRINSICS) +// Implementation for ARM64 NEON. Refer to the ARM instrinsics guide: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/ + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); // Broadcast ZeroPoint (sign-extended to 16bits) + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 15] + int8x16_t VectorS8 = vld1q_s8(Input); + + // Sign-extend into 2 vectors of 8 int16s + int16x8_t VectorS16_0 = vmovl_s8(vget_low_s8(VectorS8)); // [0 ... 7] + int16x8_t VectorS16_1 = vmovl_s8(vget_high_s8(VectorS8)); // [8 ... 15] + + // Subtract the zero-points in int16 domain. + VectorS16_0 = vsubq_s16(VectorS16_0, ZeroPointVector); + VectorS16_1 = vsubq_s16(VectorS16_1, ZeroPointVector); + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. 
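A hypothetical call into the new int8 dequantization entry point, with the expected outputs worked out from Output = (Input - ZeroPoint) * Scale. The include and surrounding setup are illustrative and assume MlasDequantizeLinear is declared in the public MLAS header:

#include <array>
#include <cstdint>
#include "mlas.h"   // assumed to declare the MlasDequantizeLinear template added by this change

void DequantizeExample()
{
    const std::array<int8_t, 4> quantized{-128, 0, 3, 127};
    std::array<float, 4> dequantized{};

    // Output[n] = (Input[n] - ZeroPoint) * Scale
    MlasDequantizeLinear<int8_t>(quantized.data(), dequantized.data(), quantized.size(),
                                 /*Scale*/ 0.5f, /*ZeroPoint*/ 1);

    // Expected: { -64.5f, -0.5f, 1.0f, 63.0f }
}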
+ MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const uint8x8_t ZeroPointVector = vdup_n_u8(ZeroPoint); // Broadcast ZeroPoint to 8 uint8s + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + uint8x16_t VectorU8 = vld1q_u8(Input); + + // Subtract zero-point. The vsubl_u8 instruction zero-extends its arguments to uint16 first. + // The reinterpret from uint16x8 to int16x8 is actually a NOP. + int16x8_t VectorS16_0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(VectorU8), ZeroPointVector)); // [0 ... 7] + int16x8_t VectorS16_1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(VectorU8), ZeroPointVector)); // [8 ... 15] + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasDequantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasDequantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); +} +#else +// Implementation that uses the scalar reference implementation. 
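The vsubl_u8 step above relies on the fact that the unsigned 16-bit difference of two uint8 values, reinterpreted as int16, already equals the signed difference, so no explicit sign handling is needed before widening to int32. A scalar model of one lane:

#include <cstdint>

// The true difference of two uint8 values always lies in [-255, 255], which fits in
// int16, and two's-complement wrap-around makes the reinterpret a no-op.
int16_t SignedDifferenceU8(uint8_t value, uint8_t zeroPoint)
{
    const uint16_t wrapped = static_cast<uint16_t>(
        static_cast<uint16_t>(value) - static_cast<uint16_t>(zeroPoint));
    return static_cast<int16_t>(wrapped);
}

// SignedDifferenceU8(3, 200) == -197 (0xFF3B); SignedDifferenceU8(200, 3) == 197 (0x00C5).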
+ +template +void +MLASCALL +MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +{ + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + +template +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ); + +#endif diff --git a/src/lib/dgemm.cpp b/src/lib/dgemm.cpp index 50c6274..bc341e6 100644 --- a/src/lib/dgemm.cpp +++ b/src/lib/dgemm.cpp @@ -26,7 +26,7 @@ Module Name: #define MLAS_DGEMM_TRANSA_ROWS 12 -#if defined (MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER) +#if defined (MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER) || defined (MLAS_TARGET_S390X) void MlasDgemmMultiplyBeta( @@ -530,7 +530,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_S390X) RowsHandled = GetMlasPlatform().GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { diff --git a/src/lib/dwconv.cpp b/src/lib/dwconv.cpp index d48d9cb..0fff937 100644 --- a/src/lib/dwconv.cpp +++ b/src/lib/dwconv.cpp @@ -43,7 +43,7 @@ MlasConvDepthwiseKernel( MLAS_FLOAT16X8 InputVector = MlasLoadFloat16x8(&Input[k][ChannelOffset]); MLAS_FLOAT16X8 FilterVector = MlasLoadFloat16x8(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAddFloat16x8(InputVector, FilterVector, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputVector, FilterVector, Accumulator); ChannelKernelOffset += Channels; } MlasStoreFloat16x8(Output, Accumulator); @@ -61,7 +61,7 @@ MlasConvDepthwiseKernel( MLAS_FLOAT16X4 InputVector = MlasLoadFloat16x4(&Input[k][ChannelOffset]); MLAS_FLOAT16X4 FilterVector = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAddFloat16x4(InputVector, FilterVector, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputVector, FilterVector, Accumulator); ChannelKernelOffset += Channels; } MlasStoreFloat16x4(Output, Accumulator); @@ -80,7 +80,7 @@ MlasConvDepthwiseKernel( MLAS_FLOAT16X4 InputValue = MlasLoadFloat16x4(&Input[k][ChannelOffset]); MLAS_FLOAT16X4 FilterValue = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAddFloat16x4(InputValue, FilterValue, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputValue, FilterValue, Accumulator); ChannelKernelOffset += Channels; } MlasStorePartialFloat16x4(Output, Accumulator, c); diff --git a/src/lib/eltwise.cpp b/src/lib/eltwise.cpp new file mode 100644 index 0000000..f63d71b --- /dev/null +++ b/src/lib/eltwise.cpp @@ -0,0 +1,71 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise.cpp + +Abstract: + + This module implements routines to compute element-wise operations on two vectors. 
+ + Currently supported element-wise operations: + - Add + +--*/ + +#include "mlasi.h" +#include "eltwise.h" + +template <> +void +MLASCALL +MlasEltwiseAdd( + const float* left, + const float* right, + float* output, + size_t N +) { + while (N > 0) { + if (N >= 4) { + MLAS_FLOAT32X4 LeftVec = MlasLoadFloat32x4(left); + MLAS_FLOAT32X4 RightVec = MlasLoadFloat32x4(right); + + MLAS_FLOAT32X4 ResultVec = MlasAddFloat32x4(LeftVec, RightVec); + + MlasStoreFloat32x4(output, ResultVec); + + left += 4; + right += 4; + output += 4; + N -= 4; + } else { + *output = *left + *right; + + left += 1; + right += 1; + output += 1; + N -= 1; + } + } +} + + +template <> +void +MLASCALL +MlasEltwiseAdd( + const MLAS_FP16* left, + const MLAS_FP16* right, + MLAS_FP16* output, + size_t N +) { + const auto* dispatch = GetMlasPlatform().EltwiseDispatch; + if (dispatch == nullptr || dispatch->Add_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Add_Fp16 is not supported."); + } + dispatch->Add_Fp16(left, right, output, N); +} diff --git a/src/lib/eltwise.h b/src/lib/eltwise.h new file mode 100644 index 0000000..a8345c4 --- /dev/null +++ b/src/lib/eltwise.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + element-wise operations. + +--*/ +#pragma once + +#include "mlasi.h" + +struct MLAS_ELTWISE_DISPATCH { + /** + * @brief Compute the element-wise addition of the two given vectors + * @param left Address of the left operand + * @param right Address of the right operand + * @param output Address of the output array. Could be the same as the input array. + * @param N Number of elements in the input arrays + */ + typedef void(Add_Fp16_Fn)( + const MLAS_FP16* left, + const MLAS_FP16* right, + MLAS_FP16* output, + size_t N + ); + + Add_Fp16_Fn* Add_Fp16 = nullptr; +}; diff --git a/src/lib/eltwise_kernel_neon.cpp b/src/lib/eltwise_kernel_neon.cpp new file mode 100644 index 0000000..415c128 --- /dev/null +++ b/src/lib/eltwise_kernel_neon.cpp @@ -0,0 +1,32 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon.cpp + +Abstract: + + This module implements the element-wise kernels for ARM NEON. + +--*/ + +#include "eltwise.h" +#include "eltwise_kernel_neon.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon = []() { + MLAS_ELTWISE_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.Add_Fp16 = eltwise_neon::Add_Kernel_Fp16; + } +#endif + return d; +}(); diff --git a/src/lib/eltwise_kernel_neon.h b/src/lib/eltwise_kernel_neon.h new file mode 100644 index 0000000..d99a3e9 --- /dev/null +++ b/src/lib/eltwise_kernel_neon.h @@ -0,0 +1,28 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + element-wise operations on ARM cpu. 
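The element-wise code above follows the library's dispatch-table idiom: a struct of kernel pointers, filled in once by an immediately-invoked lambda after a capability check, and consulted (with a null check) at the public entry point. A minimal generic sketch of that idiom; the names here are illustrative, not MLAS symbols, and a plain throw stands in for MLAS_THROW_EX:

#include <cstddef>
#include <stdexcept>

struct AddDispatch {
    using AddFn = void (*)(const float*, const float*, float*, std::size_t);
    AddFn AddKernel = nullptr;
};

static void AddScalar(const float* a, const float* b, float* out, std::size_t n)
{
    for (std::size_t i = 0; i < n; ++i) out[i] = a[i] + b[i];
}

static bool CpuSupportsFastAdd() { return true; }   // stand-in for MlasFp16AccelerationSupported()

static const AddDispatch Dispatch = []() {
    AddDispatch d;
    if (CpuSupportsFastAdd()) {
        d.AddKernel = AddScalar;                    // a real build would install the SIMD kernel here
    }
    return d;
}();

void EltwiseAdd(const float* a, const float* b, float* out, std::size_t n)
{
    if (Dispatch.AddKernel == nullptr) {
        throw std::runtime_error("Add kernel is not available on this platform.");
    }
    Dispatch.AddKernel(a, b, out, n);
}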
+ +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace eltwise_neon { + +void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N); + +} // namespace eltwise_neon diff --git a/src/lib/eltwise_kernel_neon_fp16.cpp b/src/lib/eltwise_kernel_neon_fp16.cpp new file mode 100644 index 0000000..decbdb5 --- /dev/null +++ b/src/lib/eltwise_kernel_neon_fp16.cpp @@ -0,0 +1,118 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 element-wise kernels for ARM NEON. + +--*/ +#include +#include + +#include "fp16_common.h" +#include "eltwise.h" +#include "eltwise_kernel_neon.h" + +namespace eltwise_neon { + +void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N) { + const auto* left_fp16 = reinterpret_cast(left); + const auto* right_fp16 = reinterpret_cast(right); + auto* output_fp16 = reinterpret_cast<_mlas_fp16_*>(output); + + while (N >= 32) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto l1 = MlasLoadFloat16x8(left_fp16 + 8); + auto l2 = MlasLoadFloat16x8(left_fp16 + 16); + auto l3 = MlasLoadFloat16x8(left_fp16 + 24); + + auto r0 = MlasLoadFloat16x8(right_fp16); + auto r1 = MlasLoadFloat16x8(right_fp16 + 8); + auto r2 = MlasLoadFloat16x8(right_fp16 + 16); + auto r3 = MlasLoadFloat16x8(right_fp16 + 24); + + auto o0 = MlasAddFloat16(l0, r0); + auto o1 = MlasAddFloat16(l1, r1); + auto o2 = MlasAddFloat16(l2, r2); + auto o3 = MlasAddFloat16(l3, r3); + + MlasStoreFloat16x8(output_fp16, o0); + MlasStoreFloat16x8(output_fp16 + 8, o1); + MlasStoreFloat16x8(output_fp16 + 16, o2); + MlasStoreFloat16x8(output_fp16 + 24, o3); + + left_fp16 += 32; + right_fp16 += 32; + output_fp16 += 32; + N -= 32; + } + + if (N & 16) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto l1 = MlasLoadFloat16x8(left_fp16 + 8); + + auto r0 = MlasLoadFloat16x8(right_fp16); + auto r1 = MlasLoadFloat16x8(right_fp16 + 8); + + auto o0 = MlasAddFloat16(l0, r0); + auto o1 = MlasAddFloat16(l1, r1); + + MlasStoreFloat16x8(output_fp16, o0); + MlasStoreFloat16x8(output_fp16 + 8, o1); + + left_fp16 += 16; + right_fp16 += 16; + output_fp16 += 16; + N -= 16; + } + + if (N & 8) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto r0 = MlasLoadFloat16x8(right_fp16); + auto o0 = MlasAddFloat16(l0, r0); + MlasStoreFloat16x8(output_fp16, o0); + + left_fp16 += 8; + right_fp16 += 8; + output_fp16 += 8; + N -= 8; + } + + if (N & 4) { + auto l0 = MlasLoadFloat16x4(left_fp16); + auto r0 = MlasLoadFloat16x4(right_fp16); + auto o0 = MlasAddFloat16(l0, r0); + MlasStoreFloat16x4(output_fp16, o0); + + left_fp16 += 4; + right_fp16 += 4; + output_fp16 += 4; + N -= 4; + } + + if (N == 3) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 3); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 3); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 3); + } else if (N == 2) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 2); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 2); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 2); + } else if (N == 1) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 1); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 1); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 1); + } +} + +} // namespace eltwise_neon diff --git a/src/lib/erf.cpp b/src/lib/erf.cpp index b45bd51..f972406 100644 --- 
a/src/lib/erf.cpp +++ b/src/lib/erf.cpp @@ -22,7 +22,6 @@ Module Name: --*/ #include "mlasi.h" - // // Bundles the constants for use by kernels written in assembly. // @@ -261,7 +260,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) GetMlasPlatform().ErfKernelRoutine(Input, Output, N); #else MlasErfKernel(Input, Output, N); diff --git a/src/lib/fp16_common.h b/src/lib/fp16_common.h index 30b66cd..d4713cc 100644 --- a/src/lib/fp16_common.h +++ b/src/lib/fp16_common.h @@ -27,10 +27,28 @@ typedef float16x8_t MLAS_FLOAT16X8; typedef float16x4_t MLAS_FLOAT16X4; typedef uint16x8_t MLAS_UINT16X8; typedef uint16x4_t MLAS_UINT16X4; +typedef int16x8_t MLAS_INT16X8; +typedef int16x4_t MLAS_INT16X4; MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasReinterpretAsFloat16x8(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); } +MlasReinterpretInt32AsFloat16(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasReinterpretInt16AsFloat16(MLAS_INT16X8 Vector) { return vreinterpretq_f16_s16(Vector); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasReinterpretInt16AsFloat16(MLAS_INT16X4 Vector) { return vreinterpret_f16_s16(Vector); } + +MLAS_FORCEINLINE +MLAS_INT16X8 +MlasReinterpretFloat16AsInt16(MLAS_FLOAT16X8 Vector) { return vreinterpretq_s16_f16(Vector); } + +MLAS_FORCEINLINE +MLAS_INT16X4 +MlasReinterpretFloat16AsInt16(MLAS_FLOAT16X4 Vector) { return vreinterpret_s16_f16(Vector); } MLAS_FORCEINLINE MLAS_FLOAT16X8 @@ -64,6 +82,15 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } +template +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadLaneFloat16x4(const _mlas_fp16_* Buffer, MLAS_FLOAT16X4 vec) { + return vreinterpret_f16_u16( + vld1_lane_u16(Buffer, vreinterpret_u16_f16(vec), lane) + ); +} + MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) @@ -95,6 +122,14 @@ MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) vst1_u16(Buffer, vreinterpret_u16_f16(Vector)); } +template +MLAS_FORCEINLINE +void +MlasStoreLaneFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) +{ + vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), lane); +} + MLAS_FORCEINLINE void MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len) @@ -125,94 +160,114 @@ MlasToLowHalfFloat16x4(MLAS_FLOAT16X8 V) MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasAddFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vaddq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasAddFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vadd_f16(Vector1, Vector2); } +MLAS_FORCEINLINE +MLAS_INT16X8 +MlasAddInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +{ + return vaddq_s16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_INT16X4 +MlasAddInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +{ + return vadd_s16(Vector1, Vector2); +} + MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasSubtractFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasSubtractFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vsubq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasSubtractFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasSubtractFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vsub_f16(Vector1, Vector2); } 
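Dropping the x8/x4 suffixes (MlasAddFloat16, MlasSubtractFloat16, and so on) turns these helpers into overload sets, which is what lets width-generic templates such as the MlasClampFloat16 and MlasClampInt16 helpers added later in this header call one name and have overload resolution pick the 8-lane or 4-lane intrinsic. A small illustrative sketch under that assumption:

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)

template <typename Float16Vec>
MLAS_FORCEINLINE
Float16Vec
AddThreeFloat16(Float16Vec A, Float16Vec B, Float16Vec C)
{
    // Resolves to vaddq_f16 for MLAS_FLOAT16X8 and vadd_f16 for MLAS_FLOAT16X4.
    return MlasAddFloat16(MlasAddFloat16(A, B), C);
}

#endif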
+MLAS_FORCEINLINE +MLAS_INT16X8 +MlasSubtractInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +{ + return vsubq_s16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_INT16X4 +MlasSubtractInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +{ + return vsub_s16(Vector1, Vector2); +} + MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMultiplyFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasMultiplyFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vmulq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMultiplyFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasMultiplyFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vmul_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasDivFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasDivideFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vdivq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasDivFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasDivideFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vdiv_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Vector3) +MlasMultiplyAddFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Dest) { - return vfmaq_f16(Vector3, Vector1, Vector2); + return vfmaq_f16(Dest, Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMultiplyAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Vector3) +MlasMultiplyAddFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Dest) { - return vfma_f16(Vector3, Vector1, Vector2); + return vfma_f16(Dest, Vector1, Vector2); } - MLAS_FORCEINLINE void MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, _mlas_fp16_ Scalar2, MLAS_FLOAT16X8 Vector3) { - MlasMultiplyAddFloat16x8(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3); + MlasMultiplyAddFloat16(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3); } MLAS_FORCEINLINE void MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, _mlas_fp16_ Scalar3) { - MlasMultiplyAddFloat16x8(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3)); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasDivideFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vdivq_f16(Vector1, Vector2); + MlasMultiplyAddFloat16(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3)); } MLAS_FORCEINLINE @@ -260,50 +315,127 @@ MlasBlendFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMaximumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasMaximumFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vmaxq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMaximumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasMaximumFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vmax_f16(Vector1, Vector2); } +MLAS_FORCEINLINE +MLAS_INT16X8 +MlasMaximumInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +{ + return vmaxq_s16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_INT16X4 +MlasMaximumInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +{ + return vmax_s16(Vector1, Vector2); +} + MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMinimumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasMinimumFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vminq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMinimumFloat16x4(MLAS_FLOAT16X4 
Vector1, MLAS_FLOAT16X4 Vector2) +MlasMinimumFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vmin_f16(Vector1, Vector2); } +MLAS_FORCEINLINE +MLAS_INT16X8 +MlasMinimumInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +{ + return vminq_s16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_INT16X4 +MlasMinimumInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +{ + return vmin_s16(Vector1, Vector2); +} + MLAS_FORCEINLINE MLAS_FLOAT16X8 MlasClampFloat16x8(MLAS_FLOAT16X8 Value, _mlas_fp16_ LowerRange, _mlas_fp16_ UpperRange) { - Value = MlasMaximumFloat16x8(MlasBroadcastFloat16x8(LowerRange), Value); - Value = MlasMinimumFloat16x8(MlasBroadcastFloat16x8(UpperRange), Value); + Value = MlasMaximumFloat16(MlasBroadcastFloat16x8(LowerRange), Value); + Value = MlasMaximumFloat16(MlasBroadcastFloat16x8(UpperRange), Value); + return Value; +} + +template +MLAS_FORCEINLINE +T +MlasClampFloat16(T Value, T LowerRange, T UpperRange) +{ + Value = MlasMaximumFloat16(LowerRange, Value); + Value = MlasMinimumFloat16(UpperRange, Value); + return Value; +} + +template +MLAS_FORCEINLINE +T +MlasClampInt16(T Value, T LowerRange, T UpperRange) +{ + Value = MlasMaximumInt16(LowerRange, Value); + Value = MlasMinimumInt16(UpperRange, Value); return Value; } MLAS_FORCEINLINE _mlas_fp16_ -MlasReduceAddFloat16x8(MLAS_FLOAT16X8 Vector) +MlasReduceAddFloat16(MLAS_FLOAT16X8 Vector) { + Vector = vpaddq_f16(Vector, Vector); Vector = vpaddq_f16(Vector, Vector); Vector = vpaddq_f16(Vector, Vector); return vgetq_lane_u16(vreinterpretq_u16_f16(Vector), 0); } +MLAS_FORCEINLINE +_mlas_fp16_ +MlasReduceAddFloat16(MLAS_FLOAT16X4 Vector) +{ + Vector = vpadd_f16(Vector, Vector); + Vector = vpadd_f16(Vector, Vector); + return vget_lane_u16(vreinterpret_u16_f16(Vector), 0); +} + +MLAS_FORCEINLINE +_mlas_fp16_ +MlasReduceMaximumFloat16(MLAS_FLOAT16X8 Vector) +{ + Vector = vpmaxq_f16(Vector, Vector); + Vector = vpmaxq_f16(Vector, Vector); + Vector = vpmaxq_f16(Vector, Vector); + return vgetq_lane_u16(vreinterpretq_u16_f16(Vector), 0); +} + +MLAS_FORCEINLINE +_mlas_fp16_ +MlasReduceMaximumFloat16(MLAS_FLOAT16X4 Vector) +{ + Vector = vpmax_f16(Vector, Vector); + Vector = vpmax_f16(Vector, Vector); + return vget_lane_u16(vreinterpret_u16_f16(Vector), 0); +} + MLAS_FORCEINLINE MLAS_UINT16X8 MlasCmpLessEqualFloat16x8(MLAS_FLOAT16X8 left, MLAS_FLOAT16X8 right) @@ -332,4 +464,119 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT return vbsl_f16(select, ones, zeros); } +MLAS_FORCEINLINE +void +Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3, + MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // |v40|v41|v42|v43|v44|v45|v46|v47| + // |v50|v51|v52|v53|v54|v55|v56|v57| + // |v60|v61|v62|v63|v64|v65|v66|v67| + // |v70|v71|v72|v73|v74|v75|v76|v77| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + float16x8x2_t t45 = vtrnq_f16(v4, v5); + float16x8x2_t t67 = vtrnq_f16(v6, v7); + // |v00|v10|v02|v12|v04|v14|v06|v16| + // |v01|v11|v03|v13|v05|v15|v07|v17| + // |v20|v30|v22|v32|v24|v34|v26|v36| + // |v21|v31|v23|v33|v25|v35|v27|v37| + // |v40|v50|v42|v52|v44|v54|v46|v56| + // |v41|v51|v43|v53|v45|v55|v47|v57| + // |v60|v70|v62|v72|v64|v74|v66|v76| + // |v61|v71|v63|v73|v65|v75|v67|v77| + float32x4x2_t t02 = 
vtrnq_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])); + float32x4x2_t t13 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])); + float32x4x2_t t46 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[0]), vreinterpretq_f32_f16(t67.val[0])); + float32x4x2_t t57 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[1]), vreinterpretq_f32_f16(t67.val[1])); + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + // |v40|v50|v60|v70|v44|v54|v64|v74| + // |v41|v51|v61|v71|v45|v55|v65|v75| + // |v42|v52|v62|v72|v46|v56|v66|v76| + // |v43|v53|v63|v73|v47|v57|v67|v77| + v0 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0]))); + v4 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0]))); + v2 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1]))); + v6 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1]))); + v1 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0]))); + v5 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0]))); + v3 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1]))); + v7 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1]))); + // |v00|v10|v20|v30|v40|v50|v60|v70| + // |v01|v11|v21|v31|v41|v51|v61|v71| + // |v02|v12|v22|v32|v42|v52|v62|v72| + // |v03|v13|v23|v33|v43|v53|v63|v73| + // |v04|v14|v24|v34|v44|v54|v64|v74| + // |v05|v15|v25|v35|v45|v55|v65|v75| + // |v06|v16|v26|v36|v46|v56|v66|v76| + // |v07|v17|v27|v37|v47|v57|v67|v77| +} + +MLAS_FORCEINLINE +void +Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // => + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + + v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); + v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); +} + +MLAS_FORCEINLINE +void +Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3) +{ + // |v00|v01|v02|v03| + // |v10|v11|v12|v13| + // |v20|v21|v22|v23| + // |v30|v31|v32|v33| + // => + // |v00|v10|v20|v30| + // |v01|v11|v21|v31| + // |v02|v12|v22|v32| + // |v03|v13|v23|v33| + float16x4x2_t t01 = vtrn_f16(v0, v1); + float16x4x2_t t23 = vtrn_f16(v2, v3); + + v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); + v2 = 
vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); +} + +template +MLAS_FORCEINLINE +MLAS_INT16X8 +MlasShiftLeftInt16(MLAS_INT16X8 Vector) +{ + return vshlq_n_s16(Vector, ShiftCount); +} + +template +MLAS_FORCEINLINE +MLAS_INT16X4 +MlasShiftLeftInt16(MLAS_INT16X4 Vector) +{ + return vshl_n_s16(Vector, ShiftCount); +} + #endif // fp16 vector intrinsic supported diff --git a/src/lib/halfgemm.cpp b/src/lib/halfgemm.cpp index 49387d2..66a3356 100644 --- a/src/lib/halfgemm.cpp +++ b/src/lib/halfgemm.cpp @@ -324,6 +324,253 @@ MlasHalfGemmKernel( } } +bool +MLASCALL +MlasHGemmSupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB +) { + auto* dispatch = GetMlasPlatform().HGemmDispatch; + if (TransA == CblasNoTrans && TransB == CblasTrans) { + return dispatch && + dispatch->HGemmKernel_TransposedB && + dispatch->HPackBKernel_TransposedB && + dispatch->HGemmKernel_PackedB; + } else if (TransA == CblasNoTrans && TransB == CblasNoTrans) { + return dispatch && + dispatch->HGemmKernel_B && + dispatch->HPackBKernel_B && + dispatch->HGemmKernel_PackedB; + } + + return false; +} + +void +HGemmOperation( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t K, // full K slice + const MLAS_HGEMM_DATA_PARAMS* DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) { + const size_t lda = DataParams->lda; + const size_t ldb = DataParams->ldb; + const size_t ldc = DataParams->ldc; + const _mlas_fp16_ alpha = DataParams->alpha; + const _mlas_fp16_ beta = DataParams->beta; + auto* dispatch = GetMlasPlatform().HGemmDispatch; + constexpr size_t StrideM = 2; + const auto beta_add = MLAS_FP16(1.0f); + constexpr size_t buffer_size = MLAS_HGEMM_STRIDEN * MLAS_HGEMM_STRIDEK; + + if (TransA == CblasNoTrans && TransB == CblasTrans) { + const auto* A = DataParams->A + RangeStartM * lda; + const auto* B = DataParams->B + RangeStartN * ldb; + auto* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + if (RangeCountM <= StrideM) { + if (!dispatch || !dispatch->HGemmKernel_TransposedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transposed(B) kernels"); + } + // When M is small, B is visited once. The overhead of Pack(B') exceeds the benefits + // from A x Pack(B'). Therefore directly calculate A x B'. + // Without PackB, to utilize memory locality, iterate full K. + constexpr size_t StrideN = MLAS_HGEMM_STRIDEN_THREAD_ALIGN; + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + dispatch->HGemmKernel_TransposedB(A, B, C, RangeCountM, countN, K, lda, ldb, ldc, alpha, beta); + B += countN * ldb; + C += countN; + } + } else { + if (!dispatch || !dispatch->HPackBKernel_TransposedB || !dispatch->HGemmKernel_PackedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transposed(B) kernels"); + } + // 16N is the smallest pack unit. 
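The packed-B loop that follows applies the caller's beta only on the first K block and switches to an effective beta of 1 (beta_add) for later blocks, so the blocked result still equals alpha * (A x B) + beta * C. A scalar, single-row sketch of that accumulation scheme; the names and the K x N row-major layout of B here are illustrative, not the MLAS packing:

#include <cstddef>

// Scalar, single-row model of the blocked accumulation: beta is applied only when
// k0 == 0, and later K blocks accumulate with an effective beta of 1, so the final
// value is still alpha * (A x B) + beta * C_original.
void BlockedGemmRow(const float* a, const float* b, float* c,
                    std::size_t N, std::size_t K, std::size_t StrideK,
                    float alpha, float beta)
{
    for (std::size_t k0 = 0; k0 < K; k0 += StrideK) {
        const std::size_t countK = (K - k0 < StrideK) ? (K - k0) : StrideK;
        const float effectiveBeta = (k0 == 0) ? beta : 1.0f;   // "add mode" after the first block
        for (std::size_t n = 0; n < N; ++n) {
            float partial = 0.0f;
            for (std::size_t k = 0; k < countK; ++k) {
                partial += a[k0 + k] * b[(k0 + k) * N + n];    // B stored K x N row-major here
            }
            c[n] = effectiveBeta * c[n] + alpha * partial;
        }
    }
}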
+ // TODO(fajin): optimize alpha == 1 + MLAS_DECLSPEC_ALIGN(MLAS_FP16 PackedB[buffer_size], MLAS_HGEMM_STRIDEN_THREAD_ALIGN * sizeof(_mlas_fp16_)); + size_t StrideN = MLAS_HGEMM_STRIDEN; + size_t StrideK = MLAS_HGEMM_STRIDEK; + if (RangeCountN >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + + } else { + while (StrideN > MLAS_HGEMM_STRIDEN_THREAD_ALIGN && StrideN / 2 >= RangeCountN) { + StrideK *= 2; + StrideN /= 2; + } + } + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + const MLAS_FP16* a = A; + const MLAS_FP16* b = B; + MLAS_FP16* c = C; + for (size_t k = 0, countK; k < K; k += countK) { + countK = std::min(StrideK, K - k); + dispatch->HPackBKernel_TransposedB(b, PackedB, countN, countK, ldb); + const MLAS_FP16* aa = a; + MLAS_FP16* cc = c; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + // First K iteration, beta is applied to the whole C. In rest K iterations, use add mode. + dispatch->HGemmKernel_PackedB( + aa, PackedB, cc, countM, countN, countK, lda, ldc, alpha, k == 0 ? beta : beta_add.val); + aa += countM * lda; + cc += countM * ldc; + } + a += countK; + b += countK; + } + B += countN * ldb; + C += countN; + } + } + } else if (TransA == CblasNoTrans && TransB == CblasNoTrans) { + const auto* A = DataParams->A + RangeStartM * lda; + const auto* B = DataParams->B + RangeStartN; + auto* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + if (RangeCountM <= StrideM) { + if (!dispatch || !dispatch->HGemmKernel_B) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x B kernels"); + } + + // When M is small, B is visited once. The overhead of Pack(B) exceeds the benefits + // from A x Pack(B). Therefore directly calculate A x B. + // When beta is 0 or 1, iterate full N and cache accumulators in C. + // When beta is not 0 or 1, iterate full K, accumulat in register, max 8 accumulators. + // TODO(fajin): merge beta cases with alpha == 1 + dispatch->HGemmKernel_B(A, B, C, RangeCountM, RangeCountN, K, lda, ldb, ldc, alpha, beta); + } else { + if (!dispatch || !dispatch->HPackBKernel_B || !dispatch->HGemmKernel_PackedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x B kernels"); + } + // TODO(fajin): optimize blocking for large K small N + // - pack along N + // - loop K in outer loop + // - optimize alpha == 1 case + MLAS_DECLSPEC_ALIGN(MLAS_FP16 PackedB[buffer_size], MLAS_HGEMM_STRIDEN_THREAD_ALIGN * sizeof(_mlas_fp16_)); + size_t StrideN = MLAS_HGEMM_STRIDEN; + size_t StrideK = MLAS_HGEMM_STRIDEK; + if (RangeCountN >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > MLAS_HGEMM_STRIDEN_THREAD_ALIGN && StrideN / 2 >= RangeCountN) { + StrideK *= 2; + StrideN /= 2; + } + } + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + const MLAS_FP16* a = A; + const MLAS_FP16* b = B; + MLAS_FP16* c = C; + for (size_t k = 0, countK; k < K; k += countK) { + countK = std::min(StrideK, K - k); + dispatch->HPackBKernel_B(b, PackedB, countN, countK, ldb); + const MLAS_FP16* aa = a; + MLAS_FP16* cc = c; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + // First K iteration, beta is applied to the whole C. In rest K iterations, use add mode. + dispatch->HGemmKernel_PackedB( + aa, PackedB, cc, countM, countN, countK, lda, ldc, alpha, k == 0 ? 
beta : beta_add.val); + aa += countM * lda; + cc += countM * ldc; + } + a += countK; + b += countK * ldb; + } + B += countN; + C += countN; + } + } + } else { + MLAS_THROW_EX(std::runtime_error, "hgemm currently only support A x Transpoe(B) or A x B"); + } +} + +void +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_HGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool +) { + if (!ThreadPool) { + for (size_t gemm_i = 0; gemm_i < BatchSize; gemm_i++) { + HGemmOperation(TransA, TransB, K, &Data[gemm_i], 0, M, 0, N); + } + return; + } + + const double Complexity = double(M) * double(N) * double(K) * double(BatchSize); + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_HGEMM_THREAD_COMPLEXITY)) + 1; + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // Segment the operation across multiple threads. + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchSize; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min( + nc, MlasDivRoundup(max_nc, MLAS_HGEMM_STRIDEN_THREAD_ALIGN) * MLAS_HGEMM_STRIDEN_THREAD_ALIGN); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * static_cast(BatchSize), [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + HGemmOperation(TransA, TransB, K, &Data[gemm_i], RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { MlasHalfGemmOperation, diff --git a/src/lib/halfgemm.h b/src/lib/halfgemm.h index 61e2fbb..529db48 100644 --- a/src/lib/halfgemm.h +++ b/src/lib/halfgemm.h @@ -513,3 +513,200 @@ MlasHalfGemmGetDispatch() return &MlasHalfGemmDispatchDefault; #endif } + +namespace hgemm_neon { + +void HPackB_TransposedB_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +); + +void HPackB_B_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +); + +void HGemm_TransposedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +); + +void HGemm_B_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +); + +void HGemm_PackedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + 
size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +); + +} // namespace hgemm_neon + +struct MLAS_HGEMM_DISPATCH { + /** + * @brief Pack the B matrix segment. B is column-major. Elements from CountK rows x N columns are packed + * continuously in row-major. + * First pack CountK rows x 32 columns, then pack CountK rows x 16 columns, then 8. + * If there are < 8 columns left, pad the columns with 0. + * @param B the first element of the B matrix segment. Column major. + * @param[out] PackedB the first element of the packed B matrix segment. + * @param CountN the number of columns of B chunk. + * @param CountK the number of rows of B chunk. + * @param ldb the leading dimension of B. + */ + typedef void(HPackBKernel_TransposedB_Fn) ( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb + ); + + HPackBKernel_TransposedB_Fn* HPackBKernel_TransposedB = nullptr; + + /** + * @brief Pack the B matrix segment. B is row-major. Elements from CountK rows x N columns are packed + * continuously in row-major. + * First pack CountK rows x 32 columns, then pack CountK rows x 16 columns, then 8. + * If there are < 8 columns left, pad the columns with 0. + * @param B the first element of the B matrix segment. Row major. + * @param[out] PackedB the first element of the packed B matrix segment. + * @param CountN the number of columns of B chunk. + * @param CountK the number of rows of B chunk. + * @param ldb the leading dimension of B. + */ + typedef void(HPackBKernel_B_Fn) ( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb + ); + + HPackBKernel_B_Fn* HPackBKernel_B = nullptr; + + /** + * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B is not packed. Used when M is small. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param CountK the number of columns of A chunk and the number of rows of B chunk. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + * @param alpha the alpha scalar value. + * @param beta the beta scalar value. + */ + typedef void(HGemmKernel_TransposedB_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta + ); + + HGemmKernel_TransposedB_Fn* HGemmKernel_TransposedB = nullptr; + + /** + * @brief C = alpha * A * B + beta * C. CountM <= 2. B is not packed. Used when M is small. + * + * @param A first row of the A matrix segment. Row major. + * @param B first row of the B matrix segment. Row major. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param CountK the number of columns of A chunk and the number of rows of B chunk. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + * @param alpha the alpha scalar value. + * @param beta the beta scalar value. 
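A simplified scalar model of the packed-B layout that HPackBKernel_TransposedB documents above: each column panel is stored k-major with its columns contiguous, and a short final panel is zero-padded. The real kernels greedily use 32- and 16-column panels before 8-column ones; this sketch fixes the panel width at 8 and uses float instead of fp16 purely for illustration.

#include <algorithm>
#include <cstddef>

// B is column-major with leading dimension ldb, so element (k, n) lives at B[n * ldb + k].
// Each panel is stored k-major with its (up to) 8 columns contiguous; missing columns
// in the final panel are written as zero.
void PackBColumnMajorPanel8(const float* B, float* PackedB,
                            std::size_t CountN, std::size_t CountK, std::size_t ldb)
{
    for (std::size_t n0 = 0; n0 < CountN; n0 += 8) {
        const std::size_t panelCols = std::min<std::size_t>(8, CountN - n0);
        for (std::size_t k = 0; k < CountK; ++k) {
            for (std::size_t j = 0; j < 8; ++j) {
                *PackedB++ = (j < panelCols) ? B[(n0 + j) * ldb + k] : 0.0f;
            }
        }
    }
}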
+ */ + typedef void(HGemmKernel_B_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta + ); + + HGemmKernel_B_Fn* HGemmKernel_B = nullptr; + + /** + * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B has been packed using + * HPackBKernel_TransposedB_Fn or HPackBKernel_B_Fn. Use when M is large. + * + * @param A first row of the A matrix segment. Row major. + * @param PackedB first element of the packed B buffer. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param CountK the number of columns of A chunk and the number of rows of B chunk. + * @param lda the leading dimension of A. + * @param ldc the leading dimension of C. + * @param alpha the alpha scalar value. + * @param beta the beta scalar value. + */ + typedef void(HGemmKernel_PackedB_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta + ); + + HGemmKernel_PackedB_Fn* HGemmKernel_PackedB = nullptr; +}; diff --git a/src/lib/halfgemm_kernel_neon_fp16.cpp b/src/lib/halfgemm_kernel_neon_fp16.cpp new file mode 100644 index 0000000..959df7f --- /dev/null +++ b/src/lib/halfgemm_kernel_neon_fp16.cpp @@ -0,0 +1,3174 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_kernel_neon_fp16.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. + +--*/ + +#include + +#include "halfgemm.h" +#include "fp16_common.h" + +namespace hgemm_neon { + +void HPackB_TransposedB_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +) { + const _mlas_fp16_* B_data = reinterpret_cast(B); + _mlas_fp16_* PackedB_data = reinterpret_cast<_mlas_fp16_*>(PackedB); + const bool Kr0 = (CountK % 4) > 0; + const bool Kr1 = (CountK % 4) > 1; + const bool Kr2 = (CountK % 4) > 2; + const bool Kr3 = CountK & 4; + for (; CountN >= 32; CountN -= 32, B_data += 32 * ldb) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 32; // pack 8 * 16 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + size_t baseb = 0; + size_t basep = 0; + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + for (size_t i = 0; i < 3; ++i, baseb += 8 * ldb, basep += 8) { + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + MlasStoreFloat16x8(PackedB_data + basep, v0); + MlasStoreFloat16x8(PackedB_data + basep + 32, v1); + MlasStoreFloat16x8(PackedB_data + basep + 64, v2); + MlasStoreFloat16x8(PackedB_data + basep + 96, v3); + MlasStoreFloat16x8(PackedB_data + basep + 128, v4); + MlasStoreFloat16x8(PackedB_data + basep + 160, v5); + MlasStoreFloat16x8(PackedB_data + basep + 192, v6); + MlasStoreFloat16x8(PackedB_data + basep + 224, v7); + v0 = MlasLoadFloat16x8(b + baseb + 8 * ldb); + v1 = MlasLoadFloat16x8(b + baseb + 9 * ldb); + v2 = MlasLoadFloat16x8(b + baseb + 
10 * ldb); + v3 = MlasLoadFloat16x8(b + baseb + 11 * ldb); + v4 = MlasLoadFloat16x8(b + baseb + 12 * ldb); + v5 = MlasLoadFloat16x8(b + baseb + 13 * ldb); + v6 = MlasLoadFloat16x8(b + baseb + 14 * ldb); + v7 = MlasLoadFloat16x8(b + baseb + 15 * ldb); + } + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + MlasStoreFloat16x8(PackedB_data + basep, v0); + MlasStoreFloat16x8(PackedB_data + basep + 32, v1); + MlasStoreFloat16x8(PackedB_data + basep + 64, v2); + MlasStoreFloat16x8(PackedB_data + basep + 96, v3); + MlasStoreFloat16x8(PackedB_data + basep + 128, v4); + MlasStoreFloat16x8(PackedB_data + basep + 160, v5); + MlasStoreFloat16x8(PackedB_data + basep + 192, v6); + MlasStoreFloat16x8(PackedB_data + basep + 224, v7); + } + + if (Kr3) { + size_t baseb = 0; + size_t basep = 0; + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + for (size_t i = 0; i < 3; ++i, baseb += 8 * ldb, basep += 8) { + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data + basep, v0); + MlasStoreFloat16x4(PackedB_data + basep + 4, v4); + MlasStoreFloat16x4(PackedB_data + basep + 32, v1); + MlasStoreFloat16x4(PackedB_data + basep + 36, v5); + MlasStoreFloat16x4(PackedB_data + basep + 64, v2); + MlasStoreFloat16x4(PackedB_data + basep + 68, v6); + MlasStoreFloat16x4(PackedB_data + basep + 96, v3); + MlasStoreFloat16x4(PackedB_data + basep + 100, v7); + v0 = MlasLoadFloat16x4(b + baseb + 8 * ldb); + v1 = MlasLoadFloat16x4(b + baseb + 9 * ldb); + v2 = MlasLoadFloat16x4(b + baseb + 10 * ldb); + v3 = MlasLoadFloat16x4(b + baseb + 11 * ldb); + v4 = MlasLoadFloat16x4(b + baseb + 12 * ldb); + v5 = MlasLoadFloat16x4(b + baseb + 13 * ldb); + v6 = MlasLoadFloat16x4(b + baseb + 14 * ldb); + v7 = MlasLoadFloat16x4(b + baseb + 15 * ldb); + } + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data + basep, v0); + MlasStoreFloat16x4(PackedB_data + basep + 4, v4); + MlasStoreFloat16x4(PackedB_data + basep + 32, v1); + MlasStoreFloat16x4(PackedB_data + basep + 36, v5); + MlasStoreFloat16x4(PackedB_data + basep + 64, v2); + MlasStoreFloat16x4(PackedB_data + basep + 68, v6); + MlasStoreFloat16x4(PackedB_data + basep + 96, v3); + MlasStoreFloat16x4(PackedB_data + basep + 100, v7); + k -= 4, b += 4, PackedB_data += 4 * 32; + } + + if (Kr0) { + size_t baseb = 0; + size_t basep = 0; + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + for (size_t i = 0; i < 3; ++i, baseb += 8 * ldb, basep += 8) { + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data + basep, v0); + MlasStoreFloat16x4(PackedB_data + basep + 4, v4); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + basep + 32, v1); + MlasStoreFloat16x4(PackedB_data + basep + 36, v5); + } + if (Kr2) { + 
MlasStoreFloat16x4(PackedB_data + basep + 64, v2); + MlasStoreFloat16x4(PackedB_data + basep + 68, v6); + } + v0 = MlasLoadPartialFloat16x4(b + baseb + 8 * ldb, k); + v1 = MlasLoadPartialFloat16x4(b + baseb + 9 * ldb, k); + v2 = MlasLoadPartialFloat16x4(b + baseb + 10 * ldb, k); + v3 = MlasLoadPartialFloat16x4(b + baseb + 11 * ldb, k); + v4 = MlasLoadPartialFloat16x4(b + baseb + 12 * ldb, k); + v5 = MlasLoadPartialFloat16x4(b + baseb + 13 * ldb, k); + v6 = MlasLoadPartialFloat16x4(b + baseb + 14 * ldb, k); + v7 = MlasLoadPartialFloat16x4(b + baseb + 15 * ldb, k); + + } + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data + basep, v0); + MlasStoreFloat16x4(PackedB_data + basep + 4, v4); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + basep + 32, v1); + MlasStoreFloat16x4(PackedB_data + basep + 36, v5); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + basep + 64, v2); + MlasStoreFloat16x4(PackedB_data + basep + 68, v6); + } + PackedB_data += k * 32; + } + } + + if (CountN & 16) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 16; // pack 8 * 16 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + float16x8_t v8 = MlasLoadFloat16x8(b + 8 * ldb); + float16x8_t v9 = MlasLoadFloat16x8(b + 9 * ldb); + float16x8_t vA = MlasLoadFloat16x8(b + 10 * ldb); + float16x8_t vB = MlasLoadFloat16x8(b + 11 * ldb); + float16x8_t vC = MlasLoadFloat16x8(b + 12 * ldb); + float16x8_t vD = MlasLoadFloat16x8(b + 13 * ldb); + float16x8_t vE = MlasLoadFloat16x8(b + 14 * ldb); + float16x8_t vF = MlasLoadFloat16x8(b + 15 * ldb); + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + Transpose8x8(v8, v9, vA, vB, vC, vD, vE, vF); + + MlasStoreFloat16x8(PackedB_data, v0); + MlasStoreFloat16x8(PackedB_data + 8, v8); + MlasStoreFloat16x8(PackedB_data + 16, v1); + MlasStoreFloat16x8(PackedB_data + 24, v9); + MlasStoreFloat16x8(PackedB_data + 32, v2); + MlasStoreFloat16x8(PackedB_data + 40, vA); + MlasStoreFloat16x8(PackedB_data + 48, v3); + MlasStoreFloat16x8(PackedB_data + 56, vB); + MlasStoreFloat16x8(PackedB_data + 64, v4); + MlasStoreFloat16x8(PackedB_data + 72, vC); + MlasStoreFloat16x8(PackedB_data + 80, v5); + MlasStoreFloat16x8(PackedB_data + 88, vD); + MlasStoreFloat16x8(PackedB_data + 96, v6); + MlasStoreFloat16x8(PackedB_data + 104, vE); + MlasStoreFloat16x8(PackedB_data + 112, v7); + MlasStoreFloat16x8(PackedB_data + 120, vF); + } + + if (Kr3) { + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + float16x4_t v8 = MlasLoadFloat16x4(b + 8 * ldb); + float16x4_t v9 = MlasLoadFloat16x4(b + 9 * ldb); + float16x4_t vA = MlasLoadFloat16x4(b + 10 * ldb); + float16x4_t vB = MlasLoadFloat16x4(b + 11 * ldb); + float16x4_t vC = MlasLoadFloat16x4(b + 12 * ldb); + float16x4_t vD = MlasLoadFloat16x4(b + 13 * ldb); + 
float16x4_t vE = MlasLoadFloat16x4(b + 14 * ldb); + float16x4_t vF = MlasLoadFloat16x4(b + 15 * ldb); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + Transpose4x4(v8, v9, vA, vB); + Transpose4x4(vC, vD, vE, vF); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v8); + MlasStoreFloat16x4(PackedB_data + 12, vC); + MlasStoreFloat16x4(PackedB_data + 16, v1); + MlasStoreFloat16x4(PackedB_data + 20, v5); + MlasStoreFloat16x4(PackedB_data + 24, v9); + MlasStoreFloat16x4(PackedB_data + 28, vD); + MlasStoreFloat16x4(PackedB_data + 32, v2); + MlasStoreFloat16x4(PackedB_data + 36, v6); + MlasStoreFloat16x4(PackedB_data + 40, vA); + MlasStoreFloat16x4(PackedB_data + 44, vE); + MlasStoreFloat16x4(PackedB_data + 48, v3); + MlasStoreFloat16x4(PackedB_data + 52, v7); + MlasStoreFloat16x4(PackedB_data + 56, vB); + MlasStoreFloat16x4(PackedB_data + 60, vF); + + k -= 4, b += 4, PackedB_data += 4 * 16; + } + + if (Kr0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + float16x4_t v8 = MlasLoadPartialFloat16x4(b + 8 * ldb, k); + float16x4_t v9 = MlasLoadPartialFloat16x4(b + 9 * ldb, k); + float16x4_t vA = MlasLoadPartialFloat16x4(b + 10 * ldb, k); + float16x4_t vB = MlasLoadPartialFloat16x4(b + 11 * ldb, k); + float16x4_t vC = MlasLoadPartialFloat16x4(b + 12 * ldb, k); + float16x4_t vD = MlasLoadPartialFloat16x4(b + 13 * ldb, k); + float16x4_t vE = MlasLoadPartialFloat16x4(b + 14 * ldb, k); + float16x4_t vF = MlasLoadPartialFloat16x4(b + 15 * ldb, k); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + Transpose4x4(v8, v9, vA, vB); + Transpose4x4(vC, vD, vE, vF); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v8); + MlasStoreFloat16x4(PackedB_data + 12, vC); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + 16, v1); + MlasStoreFloat16x4(PackedB_data + 20, v5); + MlasStoreFloat16x4(PackedB_data + 24, v9); + MlasStoreFloat16x4(PackedB_data + 28, vD); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + 32, v2); + MlasStoreFloat16x4(PackedB_data + 36, v6); + MlasStoreFloat16x4(PackedB_data + 40, vA); + MlasStoreFloat16x4(PackedB_data + 44, vE); + } + + PackedB_data += k * 16; + } + + CountN -= 16, B_data += 16 * ldb; + } + + if (CountN & 8) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 8; // pack 8 * 8 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + MlasStoreFloat16x8(PackedB_data, v0); + MlasStoreFloat16x8(PackedB_data + 8, v1); + MlasStoreFloat16x8(PackedB_data + 16, v2); + MlasStoreFloat16x8(PackedB_data + 24, v3); + 
MlasStoreFloat16x8(PackedB_data + 32, v4); + MlasStoreFloat16x8(PackedB_data + 40, v5); + MlasStoreFloat16x8(PackedB_data + 48, v6); + MlasStoreFloat16x8(PackedB_data + 56, v7); + } + + if (Kr3) { + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v1); + MlasStoreFloat16x4(PackedB_data + 12, v5); + MlasStoreFloat16x4(PackedB_data + 16, v2); + MlasStoreFloat16x4(PackedB_data + 20, v6); + MlasStoreFloat16x4(PackedB_data + 24, v3); + MlasStoreFloat16x4(PackedB_data + 28, v7); + k -= 4, b += 4, PackedB_data += 4 * 8; + } + + if (Kr0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + 8, v1); + MlasStoreFloat16x4(PackedB_data + 12, v5); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + 16, v2); + MlasStoreFloat16x4(PackedB_data + 20, v6); + } + + PackedB_data += k * 8; + } + + B_data += 8 * ldb; + CountN -= 8; + } + + if (CountN > 0) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 8; // pack extended 8 * 8 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadFloat16x8(b + i * ldb); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x8(); + } + Transpose8x8(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]); + MlasStoreFloat16x8(PackedB_data, v[0]); + MlasStoreFloat16x8(PackedB_data + 8, v[1]); + MlasStoreFloat16x8(PackedB_data + 16, v[2]); + MlasStoreFloat16x8(PackedB_data + 24, v[3]); + MlasStoreFloat16x8(PackedB_data + 32, v[4]); + MlasStoreFloat16x8(PackedB_data + 40, v[5]); + MlasStoreFloat16x8(PackedB_data + 48, v[6]); + MlasStoreFloat16x8(PackedB_data + 56, v[7]); + } + + if (Kr3) { + float16x4_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x4(); + } + Transpose4x4(v[0], v[1], v[2], v[3]); + Transpose4x4(v[4], v[5], v[6], v[7]); + MlasStoreFloat16x4(PackedB_data, v[0]); + MlasStoreFloat16x4(PackedB_data + 4, v[4]); + MlasStoreFloat16x4(PackedB_data + 8, v[1]); + MlasStoreFloat16x4(PackedB_data + 12, v[5]); + MlasStoreFloat16x4(PackedB_data + 16, v[2]); + MlasStoreFloat16x4(PackedB_data + 20, v[6]); + MlasStoreFloat16x4(PackedB_data + 24, v[3]); + MlasStoreFloat16x4(PackedB_data + 28, v[7]); + k -= 4, b += 4, PackedB_data += 4 * 8; + } + + if (Kr0) { + float16x4_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = 
MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x4(); + } + Transpose4x4(v[0], v[1], v[2], v[3]); + Transpose4x4(v[4], v[5], v[6], v[7]); + MlasStoreFloat16x4(PackedB_data, v[0]); + MlasStoreFloat16x4(PackedB_data + 4, v[4]); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + 8, v[1]); + MlasStoreFloat16x4(PackedB_data + 12, v[5]); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + 16, v[2]); + MlasStoreFloat16x4(PackedB_data + 20, v[6]); + } + } + } +} + +void HPackB_B_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +) { + const _mlas_fp16_* B_data = reinterpret_cast<const _mlas_fp16_*>(B); + _mlas_fp16_* PackedB_data = reinterpret_cast<_mlas_fp16_*>(PackedB); + + for (; CountN >= 32; CountN -= 32, B_data += 32) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + uint16x8x4_t v0 = vld4q_u16(b); + for (; k >= 2; --k, b += ldb, PackedB_data += 32) { + vst4q_u16(PackedB_data, v0); + v0 = vld4q_u16(b + ldb); + } + vst4q_u16(PackedB_data, v0); + PackedB_data += 32; + } + + if (CountN & 16) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + uint16x8x2_t v0 = vld2q_u16(b); + for (; k >= 2; --k, b += ldb, PackedB_data += 16) { + vst2q_u16(PackedB_data, v0); + v0 = vld2q_u16(b + ldb); + } + vst2q_u16(PackedB_data, v0); + PackedB_data += 16; + CountN -= 16, B_data += 16; + } + + if (CountN & 8) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + uint16x8_t v0 = vld1q_u16(b); + for (; k >= 2; --k, b += ldb, PackedB_data += 8) { + vst1q_u16(PackedB_data, v0); + v0 = vld1q_u16(b + ldb); + } + vst1q_u16(PackedB_data, v0); + PackedB_data += 8; + + B_data += 8; + CountN -= 8; + } + + if (CountN > 4) { + float16x4_t v0 = MlasLoadFloat16x4(B_data); + float16x4_t v1 = MlasLoadPartialFloat16x4(B_data + 4, CountN - 4); + for (; CountK >= 2; B_data += ldb, PackedB_data += 8, --CountK) { + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v1); + v0 = MlasLoadFloat16x4(B_data + ldb); + v1 = MlasLoadPartialFloat16x4(B_data + ldb + 4, CountN - 4); + } + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v1); + } else if (CountN > 0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(B_data, CountN); + for (; CountK >= 2; B_data += ldb, PackedB_data += 8, --CountK) { + MlasStoreFloat16x4(PackedB_data, v0); + v0 = MlasLoadPartialFloat16x4(B_data + ldb, CountN); + } + MlasStoreFloat16x4(PackedB_data, v0); + } +} + +MLAS_FORCEINLINE +float16x8_t addq_f16x4(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3) { + v0 = vaddq_f16(v0, v1); + v2 = vaddq_f16(v2, v3); + v0 = vaddq_f16(v0, v2); + return v0; +} + +MLAS_FORCEINLINE +float16x8_t addq_f16x8(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x8_t v4, float16x8_t v5, float16x8_t v6, float16x8_t v7) { + return vaddq_f16(addq_f16x4(v0, v1, v2, v3), addq_f16x4(v4, v5, v6, v7)); +} + +MLAS_FORCEINLINE +float16x8_t maq_lane_f16_accu(float16x8_t accu0, float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x4_t a0) { + accu0 = vfmaq_lane_f16(accu0, v0, a0, 0); + accu0 = vfmaq_lane_f16(accu0, v1, a0, 1); + accu0 = vfmaq_lane_f16(accu0, v2, a0, 2); + accu0 = vfmaq_lane_f16(accu0, v3, a0, 3); + return accu0; +} + +MLAS_FORCEINLINE +float16x8_t maq_laneq_f16_accu(float16x8_t accu0, float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x8_t v4, float16x8_t v5, float16x8_t v6, float16x8_t v7, float16x8_t a0) { + accu0 = vfmaq_laneq_f16(accu0, v0, a0, 0); + accu0 =
vfmaq_laneq_f16(accu0, v1, a0, 1); + accu0 = vfmaq_laneq_f16(accu0, v2, a0, 2); + accu0 = vfmaq_laneq_f16(accu0, v3, a0, 3); + accu0 = vfmaq_laneq_f16(accu0, v4, a0, 4); + accu0 = vfmaq_laneq_f16(accu0, v5, a0, 5); + accu0 = vfmaq_laneq_f16(accu0, v6, a0, 6); + accu0 = vfmaq_laneq_f16(accu0, v7, a0, 7); + return accu0; +} + +MLAS_FORCEINLINE +float16x4_t ma_laneq_f16_accu(float16x4_t accu0, float16x4_t v0, float16x4_t v1, float16x4_t v2, float16x4_t v3, + float16x4_t v4, float16x4_t v5, float16x4_t v6, float16x4_t v7, float16x8_t a0) { + accu0 = vfma_laneq_f16(accu0, v0, a0, 0); + accu0 = vfma_laneq_f16(accu0, v1, a0, 1); + accu0 = vfma_laneq_f16(accu0, v2, a0, 2); + accu0 = vfma_laneq_f16(accu0, v3, a0, 3); + accu0 = vfma_laneq_f16(accu0, v4, a0, 4); + accu0 = vfma_laneq_f16(accu0, v5, a0, 5); + accu0 = vfma_laneq_f16(accu0, v6, a0, 6); + accu0 = vfma_laneq_f16(accu0, v7, a0, 7); + return accu0; +} + +MLAS_FORCEINLINE +float16x4_t ma_lane_f16_accu(float16x4_t accu, float16x4_t v0, float16x4_t v1, float16x4_t v2, float16x4_t v3, + float16x4_t a0) { + accu = vfma_lane_f16(accu, v0, a0, 0); + accu = vfma_lane_f16(accu, v1, a0, 1); + accu = vfma_lane_f16(accu, v2, a0, 2); + accu = vfma_lane_f16(accu, v3, a0, 3); + return accu; +} + +// beta_behavior: beta == 0.0f16 -> 0, beta == 1.0f16 -> 1, otherwise -> 2 +template <int beta_behavior, int CountM> +void HGemm_TransposedB_Kernel_Impl( + const _mlas_fp16_* A_data, + const _mlas_fp16_* B_data, + _mlas_fp16_* C_data, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + const bool largeK = CountK >= 8; + const bool Kr0 = (CountK & 3); + const bool Kr1 = (CountK & 3) > 1; + const bool Kr2 = (CountK & 3) > 2; + const bool Kr3 = (CountK & 4); + for (; CountN >= 8; CountN -= 8, B_data += 8 * ldb, C_data += 8) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu04 = MlasZeroFloat16x8(); + float16x8_t accu05 = MlasZeroFloat16x8(); + float16x8_t accu06 = MlasZeroFloat16x8(); + float16x8_t accu07 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + accu14 = MlasZeroFloat16x8(); + accu15 = MlasZeroFloat16x8(); + accu16 = MlasZeroFloat16x8(); + accu17 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t b5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t b6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t b7 = MlasLoadFloat16x8(b + 7 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + for (; k >= 16; k -= 8, a += 8, b += 8) { + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + accu04 = vfmaq_f16(accu04, b4, a0); + accu05 = vfmaq_f16(accu05, b5, a0); + accu06 = vfmaq_f16(accu06, b6, a0); + accu07 = vfmaq_f16(accu07, b7, a0); + if constexpr (CountM == 2) { + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 =
vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + accu14 = vfmaq_f16(accu14, b4, a1); + accu15 = vfmaq_f16(accu15, b5, a1); + accu16 = vfmaq_f16(accu16, b6, a1); + accu17 = vfmaq_f16(accu17, b7, a1); + } + b0 = MlasLoadFloat16x8(b + 8); + b1 = MlasLoadFloat16x8(b + 8 + ldb); + b2 = MlasLoadFloat16x8(b + 8 + 2 * ldb); + b3 = MlasLoadFloat16x8(b + 8 + 3 * ldb); + b4 = MlasLoadFloat16x8(b + 8 + 4 * ldb); + b5 = MlasLoadFloat16x8(b + 8 + 5 * ldb); + b6 = MlasLoadFloat16x8(b + 8 + 6 * ldb); + b7 = MlasLoadFloat16x8(b + 8 + 7 * ldb); + a0 = MlasLoadFloat16x8(a + 8); + } + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + accu04 = vfmaq_f16(accu04, b4, a0); + accu05 = vfmaq_f16(accu05, b5, a0); + accu06 = vfmaq_f16(accu06, b6, a0); + accu07 = vfmaq_f16(accu07, b7, a0); + if constexpr (CountM == 2) { + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + accu14 = vfmaq_f16(accu14, b4, a1); + accu15 = vfmaq_f16(accu15, b5, a1); + accu16 = vfmaq_f16(accu16, b6, a1); + accu17 = vfmaq_f16(accu17, b7, a1); + } + k -= 8, a += 8, b += 8; + } + Transpose8x8(accu00, accu01, accu02, accu03, accu04, accu05, accu06, accu07); + accu00 = addq_f16x8(accu00, accu01, accu02, accu03, accu04, accu05, accu06, accu07); + if constexpr (CountM == 2) { + Transpose8x8(accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17); + accu10 = addq_f16x8(accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17); + } + + if (Kr3) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t b4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t b5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t b6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t b7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x8_t v1 = vcombine_f16(b1, b5); + float16x8_t v2 = vcombine_f16(b2, b6); + float16x8_t v3 = vcombine_f16(b3, b7); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu00 = maq_lane_f16_accu(accu00, v0, v1, v2, v3, a0); + if constexpr (CountM == 2) { + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu10 = maq_lane_f16_accu(accu10, v0, v1, v2, v3, a1); + } + k -= 4, a += 4, b += 4; + } + + if (Kr0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t b4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t b5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t b6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t b7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + accu00 = vfmaq_lane_f16(accu00, v0, a0, 0); + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu10 = vfmaq_lane_f16(accu10, v0, a1, 0); + } + if (Kr1) { + float16x8_t v1 = vcombine_f16(b1, b5); + accu00 = vfmaq_lane_f16(accu00, v1, a0, 1); + if 
constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, v1, a1, 1); + } + } + if (Kr2) { + float16x8_t v2 = vcombine_f16(b2, b6); + accu00 = vfmaq_lane_f16(accu00, v2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, v2, a1, 2); + } + } + } + + if constexpr (beta_behavior == 1) { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t c0 = MlasLoadFloat16x8(C_data); + accu00 = vfmaq_f16(c0, accu00, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + if constexpr (CountM == 2) { + float16x8_t c1 = MlasLoadFloat16x8(C_data + ldc); + accu10 = vfmaq_f16(c1, accu10, alpha_v); + MlasStoreFloat16x8(C_data + ldc, accu10); + } + } else if constexpr (beta_behavior == 2) { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + float16x8_t c0 = MlasLoadFloat16x8(C_data); + accu00 = vfmaq_f16(vmulq_f16(c0, beta_v), accu00, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + if constexpr (CountM == 2) { + float16x8_t c1 = MlasLoadFloat16x8(C_data + ldc); + accu10 = vfmaq_f16(vmulq_f16(c1, beta_v), accu10, alpha_v); + MlasStoreFloat16x8(C_data + ldc, accu10); + } + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vmulq_f16(accu00, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + if constexpr (CountM == 2) { + accu10 = vmulq_f16(accu10, alpha_v); + MlasStoreFloat16x8(C_data + ldc, accu10); + } + } + } + + if (CountN & 4) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + for (; k >= 16; k -= 8, a += 8, b += 8) { + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + if constexpr (CountM == 2) { + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + } + b0 = MlasLoadFloat16x8(b + 8); + b1 = MlasLoadFloat16x8(b + 8 + ldb); + b2 = MlasLoadFloat16x8(b + 8 + 2 * ldb); + b3 = MlasLoadFloat16x8(b + 8 + 3 * ldb); + a0 = MlasLoadFloat16x8(a + 8); + } + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + if constexpr (CountM == 2) { + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + } + k -= 8, a += 8, b += 8; + } + Transpose4x8(accu00, accu01, accu02, accu03); + accu00 = addq_f16x4(accu00, accu01, accu02, accu03); + float16x4_t accu0 = vadd_f16(vget_low_f16(accu00), vget_high_f16(accu00)), accu1; + if constexpr (CountM == 2) { + Transpose4x8(accu10, accu11, accu12, accu13); + accu10 = addq_f16x4(accu10, accu11, accu12, accu13); + accu1 = 
vadd_f16(vget_low_f16(accu10), vget_high_f16(accu10)); + } + + if (Kr3) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = ma_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu1 = ma_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + } + k -= 4, a += 4, b += 4; + } + + if (Kr0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + accu0 = vfma_lane_f16(accu0, b0, a0, 0); + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu1 = vfma_lane_f16(accu1, b0, a1, 0); + } + if (Kr1) { + accu0 = vfma_lane_f16(accu0, b1, a0, 1); + if constexpr (CountM == 2) { + accu1 = vfma_lane_f16(accu1, b1, a1, 1); + } + } + if (Kr2) { + accu0 = vfma_lane_f16(accu0, b2, a0, 2); + if constexpr (CountM == 2) { + accu1 = vfma_lane_f16(accu1, b2, a1, 2); + } + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t c0 = MlasLoadFloat16x4(C_data); + accu0 = vfma_f16(c0, accu0, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadFloat16x4(C_data + ldc); + accu1 = vfma_f16(c1, accu1, alpha_v); + MlasStoreFloat16x4(C_data + ldc, accu1); + } + } else if constexpr (beta_behavior == 2) { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + float16x4_t c0 = MlasLoadFloat16x4(C_data); + accu0 = vfma_f16(vmul_f16(c0, beta_v), accu0, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadFloat16x4(C_data + ldc); + accu1 = vfma_f16(vmul_f16(c1, beta_v), accu1, alpha_v); + MlasStoreFloat16x4(C_data + ldc, accu1); + } + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu0 = vmul_f16(accu0, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + if constexpr (CountM == 2) { + accu1 = vmul_f16(accu1, alpha_v); + MlasStoreFloat16x4(C_data + ldc, accu1); + } + } + + CountN -= 4, B_data += 4 * ldb, C_data += 4; + } + + if (CountN > 0) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu0[4], accu1[4]; + size_t i = 0; + for (i = 0; i < 4; ++i) { + accu0[i] = MlasZeroFloat16x8(); + if constexpr (CountM == 2) { + accu1[i] = MlasZeroFloat16x8(); + } + } + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t a0 = MlasLoadFloat16x8(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x8(a + lda); + } + for (i = 0; i < CountN; ++i) { + float16x8_t bi = MlasLoadFloat16x8(b + i * ldb); + accu0[i] = vfmaq_f16(accu0[i], bi, a0); + if constexpr (CountM == 2) { + accu1[i] = vfmaq_f16(accu1[i], bi, a1); + } + } + } + Transpose4x8(accu0[0], accu0[1], accu0[2], accu0[3]); + float16x8_t accu00 = addq_f16x4(accu0[0], accu0[1], accu0[2], accu0[3]); + float16x4_t accu_0 = vadd_f16(vget_low_f16(accu00), vget_high_f16(accu00)), accu_1; + if constexpr (CountM == 2) { + Transpose4x8(accu1[0], accu1[1], accu1[2], accu1[3]); + float16x8_t accu10 = addq_f16x4(accu1[0], accu1[1], accu1[2], accu1[3]); + accu_1 
= vadd_f16(vget_low_f16(accu10), vget_high_f16(accu10)); + } + + if (Kr3) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu_0 = ma_lane_f16_accu(accu_0, bs[0], bs[1], bs[2], bs[3], a0); + if constexpr (CountM == 2) { + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu_1 = ma_lane_f16_accu(accu_1, bs[0], bs[1], bs[2], bs[3], a1); + } + k -= 4, a += 4, b += 4; + } + + if (Kr0) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + accu_0 = vfma_lane_f16(accu_0, bs[0], a0, 0); + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu_1 = vfma_lane_f16(accu_1, bs[0], a1, 0); + } + if (Kr1) { + accu_0 = vfma_lane_f16(accu_0, bs[1], a0, 1); + if constexpr (CountM == 2) { + accu_1 = vfma_lane_f16(accu_1, bs[1], a1, 1); + } + } + if (Kr2) { + accu_0 = vfma_lane_f16(accu_0, bs[2], a0, 2); + if constexpr (CountM == 2) { + accu_1 = vfma_lane_f16(accu_1, bs[2], a1, 2); + } + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t c0 = MlasLoadPartialFloat16x4(C_data, CountN); + accu_0 = vfma_f16(c0, accu_0, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + accu_1 = vfma_f16(c1, accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } + } else if constexpr (beta_behavior == 2) { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + float16x4_t c0 = MlasLoadPartialFloat16x4(C_data, CountN); + accu_0 = vfma_f16(vmul_f16(c0, beta_v), accu_0, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + accu_1 = vfma_f16(vmul_f16(c1, beta_v), accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu_0 = vmul_f16(accu_0, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + if constexpr (CountM == 2) { + accu_1 = vmul_f16(accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } + } + } +} + +// Full K. Directly save to C. 
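+// Dispatches on CountM (at most 2 rows of A) and on beta: 0 skips reading C, 1 adds C
+// unscaled, anything else scales C by beta, mirroring the template arguments of
+// HGemm_TransposedB_Kernel_Impl above.
+//
+// Illustrative call (a sketch only; N and K are placeholder dimensions, with row-major
+// buffers A[M x K], B[N x K], C[M x N]):
+//   HGemm_TransposedB_Kernel(A, B, C, /*CountM*/ 2, /*CountN*/ N, /*CountK*/ K,
+//                            /*lda*/ K, /*ldb*/ K, /*ldc*/ N,
+//                            MLAS_FP16(1.0f).val, MLAS_FP16(0.0f).val);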
+void HGemm_TransposedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_TransposedB_Kernel only support <= 2 rows"); + } + const auto* A_data = reinterpret_cast<const _mlas_fp16_*>(A); + const auto* B_data = reinterpret_cast<const _mlas_fp16_*>(B); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (beta == f16_0.val) { + HGemm_TransposedB_Kernel_Impl<0, 1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedB_Kernel_Impl<1, 1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else { + HGemm_TransposedB_Kernel_Impl<2, 1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } else { + if (beta == f16_0.val) { + HGemm_TransposedB_Kernel_Impl<0, 2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedB_Kernel_Impl<1, 2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else { + HGemm_TransposedB_Kernel_Impl<2, 2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } +} + +// handle C = alpha * A * B + beta * C where alpha != 1 or beta != 0 or 1 +template <int CountM> +void HGemm_B_Kernel_Complicated( + const _mlas_fp16_* A_data, + const _mlas_fp16_* B_data, + _mlas_fp16_* C_data, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + const size_t ldb4 = ldb * 4; + float16x8_t alpha_v8 = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v8 = MlasBroadcastFloat16x8(beta); + float16x4_t alpha_v4 = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v4 = MlasBroadcastFloat16x4(beta); + const bool largeK = CountK >= 4; + const bool Kr0 = CountK & 3; + const bool Kr1 = (CountK & 3) > 1; + const bool Kr2 = (CountK & 3) > 2; + for (; CountN >= 32; CountN -= 32, B_data += 32, C_data += 32) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b12 = MlasLoadFloat16x8(b + ldb + 16); + float16x8_t b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + float16x8_t b32 = MlasLoadFloat16x8(b + 3 * ldb + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b13 = MlasLoadFloat16x8(b + ldb + 24); + float16x8_t b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + float16x8_t
b33 = MlasLoadFloat16x8(b + 3 * ldb + 24); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadFloat16x8(b + ldb4); + b10 = MlasLoadFloat16x8(b + 5 * ldb); + b20 = MlasLoadFloat16x8(b + 6 * ldb); + b30 = MlasLoadFloat16x8(b + 7 * ldb); + b01 = MlasLoadFloat16x8(b + ldb4 + 8); + b11 = MlasLoadFloat16x8(b + 5 * ldb + 8); + b21 = MlasLoadFloat16x8(b + 6 * ldb + 8); + b31 = MlasLoadFloat16x8(b + 7 * ldb + 8); + b02 = MlasLoadFloat16x8(b + ldb4 + 16); + b12 = MlasLoadFloat16x8(b + 5 * ldb + 16); + b22 = MlasLoadFloat16x8(b + 6 * ldb + 16); + b32 = MlasLoadFloat16x8(b + 7 * ldb + 16); + b03 = MlasLoadFloat16x8(b + ldb4 + 24); + b13 = MlasLoadFloat16x8(b + 5 * ldb + 24); + b23 = MlasLoadFloat16x8(b + 6 * ldb + 24); + b33 = MlasLoadFloat16x8(b + 7 * ldb + 24); + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b12 = MlasLoadFloat16x8(b + ldb + 16); + float16x8_t b13 = MlasLoadFloat16x8(b + ldb + 24); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + float16x8_t b23 = 
MlasLoadFloat16x8(b + 2 * ldb + 24); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + } + + float16x8_t c00 = MlasLoadFloat16x8(C_data); + float16x8_t c01 = MlasLoadFloat16x8(C_data + 8); + float16x8_t c02 = MlasLoadFloat16x8(C_data + 16); + float16x8_t c03 = MlasLoadFloat16x8(C_data + 24); + MlasStoreFloat16x8(C_data, vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8)); + MlasStoreFloat16x8(C_data + 8, vfmaq_f16(vmulq_f16(c01, beta_v8), accu01, alpha_v8)); + MlasStoreFloat16x8(C_data + 16, vfmaq_f16(vmulq_f16(c02, beta_v8), accu02, alpha_v8)); + MlasStoreFloat16x8(C_data + 24, vfmaq_f16(vmulq_f16(c03, beta_v8), accu03, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C_data + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C_data + ldc + 8); + float16x8_t c12 = MlasLoadFloat16x8(C_data + ldc + 16); + float16x8_t c13 = MlasLoadFloat16x8(C_data + ldc + 24); + MlasStoreFloat16x8(C_data + ldc, vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8)); + MlasStoreFloat16x8(C_data + ldc + 8, vfmaq_f16(vmulq_f16(c11, beta_v8), accu11, alpha_v8)); + MlasStoreFloat16x8(C_data + ldc + 16, vfmaq_f16(vmulq_f16(c12, beta_v8), accu12, alpha_v8)); + MlasStoreFloat16x8(C_data + ldc + 24, vfmaq_f16(vmulq_f16(c13, beta_v8), accu13, alpha_v8)); + } + } + + if (CountN & 16) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadFloat16x8(b + ldb4); + b10 = MlasLoadFloat16x8(b + 5 * ldb); + b20 = MlasLoadFloat16x8(b + 6 * ldb); + b30 = MlasLoadFloat16x8(b + 7 * ldb); + b01 = MlasLoadFloat16x8(b + ldb4 + 8); + b11 = MlasLoadFloat16x8(b + 5 * ldb + 8); + b21 = MlasLoadFloat16x8(b + 6 * ldb + 8); + b31 = MlasLoadFloat16x8(b + 7 * ldb + 8); + + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + } + 
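+            // Consume the final pre-loaded 4-element K step (handled by the multiply-accumulates above) and advance a/b/k to match.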
k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + } + } + + float16x8_t c00 = MlasLoadFloat16x8(C_data); + float16x8_t c01 = MlasLoadFloat16x8(C_data + 8); + MlasStoreFloat16x8(C_data, vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8)); + MlasStoreFloat16x8(C_data + 8, vfmaq_f16(vmulq_f16(c01, beta_v8), accu01, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C_data + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C_data + ldc + 8); + MlasStoreFloat16x8(C_data + ldc, vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8)); + MlasStoreFloat16x8(C_data + ldc + 8, vfmaq_f16(vmulq_f16(c11, beta_v8), accu11, alpha_v8)); + } + + CountN -= 16, B_data += 16, C_data += 16; + } + + if (CountN & 8) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadFloat16x8(b + ldb4); + b10 = MlasLoadFloat16x8(b + 5 * ldb); + b20 = MlasLoadFloat16x8(b + 6 * ldb); + b30 = MlasLoadFloat16x8(b + 7 * ldb); + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, 
a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + } + } + } + + float16x8_t c00 = MlasLoadFloat16x8(C_data); + MlasStoreFloat16x8(C_data, vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C_data + ldc); + MlasStoreFloat16x8(C_data + ldc, vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8)); + } + + CountN -= 8, B_data += 8, C_data += 8; + } + + if (CountN & 4) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x4_t accu00 = MlasZeroFloat16x4(); + float16x4_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x4(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x4_t b00 = MlasLoadFloat16x4(b); + float16x4_t b10 = MlasLoadFloat16x4(b + ldb); + float16x4_t b20 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b30 = MlasLoadFloat16x4(b + 3 * ldb); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = ma_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = ma_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadFloat16x4(b + ldb4); + b10 = MlasLoadFloat16x4(b + 5 * ldb); + b20 = MlasLoadFloat16x4(b + 6 * ldb); + b30 = MlasLoadFloat16x4(b + 7 * ldb); + } + accu00 = ma_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = ma_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x4_t b00 = MlasLoadFloat16x4(b); + accu00 = vfma_lane_f16(accu00, b00, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b00, a1, 0); + } + if (Kr1) { + float16x4_t b10 = MlasLoadFloat16x4(b + ldb); + accu00 = vfma_lane_f16(accu00, b10, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b10, a1, 1); + } + } + if (Kr2) { + float16x4_t b20 = MlasLoadFloat16x4(b + 2 * ldb); + accu00 = vfma_lane_f16(accu00, b20, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b20, a1, 2); + } + } + } + + float16x4_t c00 = MlasLoadFloat16x4(C_data); + MlasStoreFloat16x4(C_data, vfma_f16(vmul_f16(c00, beta_v4), accu00, alpha_v4)); + if constexpr (CountM == 2) { + float16x4_t c10 = MlasLoadFloat16x4(C_data + ldc); + MlasStoreFloat16x4(C_data + ldc, vfma_f16(vmul_f16(c10, beta_v4), accu10, alpha_v4)); + } + + CountN -= 4, B_data += 4, C_data += 4; + } + + if (CountN > 0) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x4_t accu00 = MlasZeroFloat16x4(); + float16x4_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x4(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x4_t b00 = MlasLoadPartialFloat16x4(b, CountN); + float16x4_t b10 = MlasLoadPartialFloat16x4(b + ldb, CountN); + float16x4_t b20 = MlasLoadPartialFloat16x4(b + 2 * ldb, CountN); + float16x4_t b30 = MlasLoadPartialFloat16x4(b + 3 * ldb, CountN); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = ma_lane_f16_accu(accu00, b00, 
b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = ma_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadPartialFloat16x4(b + ldb4, CountN); + b10 = MlasLoadPartialFloat16x4(b + 5 * ldb, CountN); + b20 = MlasLoadPartialFloat16x4(b + 6 * ldb, CountN); + b30 = MlasLoadPartialFloat16x4(b + 7 * ldb, CountN); + } + accu00 = ma_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = ma_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x4_t b00 = MlasLoadPartialFloat16x4(b, CountN); + accu00 = vfma_lane_f16(accu00, b00, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b00, a1, 0); + } + if (Kr1) { + float16x4_t b10 = MlasLoadPartialFloat16x4(b + ldb, CountN); + accu00 = vfma_lane_f16(accu00, b10, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b10, a1, 1); + } + } + if (Kr2) { + float16x4_t b20 = MlasLoadPartialFloat16x4(b + 2 * ldb, CountN); + accu00 = vfma_lane_f16(accu00, b20, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b20, a1, 2); + } + } + } + + float16x4_t c00 = MlasLoadPartialFloat16x4(C_data, CountN); + MlasStorePartialFloat16x4(C_data, vfma_f16(vmul_f16(c00, beta_v4), accu00, alpha_v4), CountN); + if constexpr (CountM == 2) { + float16x4_t c10 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + MlasStorePartialFloat16x4(C_data + ldc, vfma_f16(vmul_f16(c10, beta_v4), accu10, alpha_v4), CountN); + } + } +} + +// Handle C = A * B + C or C = A * B +template <int CountM, bool zero_mode> +void HGemm_B_Kernel_Simple( + const _mlas_fp16_* A_data, + const _mlas_fp16_* B_data, + _mlas_fp16_* C_data, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc +) { + const size_t ldb4 = ldb * 4; + const bool largeN = CountN >= 32; + const bool Nr0 = (CountN % 4) > 0; + const bool Kr1 = (CountK % 4) > 1; + const bool Kr2 = (CountK % 4) > 2; + if constexpr (zero_mode) { + // process first K + if (CountK >= 4) { + float16x4_t a0 = MlasLoadFloat16x4(A_data), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(A_data + lda); + } + size_t n = CountN; + const auto* b = B_data; + auto* c = C_data; + if (largeN) { + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b12 = MlasLoadFloat16x8(b + ldb + 16); + float16x8_t b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + float16x8_t b32 = MlasLoadFloat16x8(b + 3 * ldb + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b13 = MlasLoadFloat16x8(b + ldb + 24); + float16x8_t b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + float16x8_t b33 = MlasLoadFloat16x8(b + 3 * ldb + 24); + for (; n >= 64; n -= 32, b += 32, c += 32) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + accu00 = maq_lane_f16_accu(accu00,
b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + float16x8_t accu12 = MlasZeroFloat16x8(); + float16x8_t accu13 = MlasZeroFloat16x8(); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + b00 = MlasLoadFloat16x8(b + 32); + b10 = MlasLoadFloat16x8(b + ldb + 32); + b20 = MlasLoadFloat16x8(b + 2 * ldb + 32); + b30 = MlasLoadFloat16x8(b + 3 * ldb + 32); + b01 = MlasLoadFloat16x8(b + 40); + b11 = MlasLoadFloat16x8(b + ldb + 40); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 40); + b31 = MlasLoadFloat16x8(b + 3 * ldb + 40); + b02 = MlasLoadFloat16x8(b + 48); + b12 = MlasLoadFloat16x8(b + ldb + 48); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 48); + b32 = MlasLoadFloat16x8(b + 3 * ldb + 48); + b03 = MlasLoadFloat16x8(b + 56); + b13 = MlasLoadFloat16x8(b + ldb + 56); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 56); + b33 = MlasLoadFloat16x8(b + 3 * ldb + 56); + } + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + float16x8_t accu12 = MlasZeroFloat16x8(); + float16x8_t accu13 = MlasZeroFloat16x8(); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + n -= 32, b += 32, c += 32; + } + if (n & 16) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + MlasStoreFloat16x8(c, accu00); + 
MlasStoreFloat16x8(c + 8, accu01); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + } + n -= 16, b += 16, c += 16; + } + if (n & 8) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + MlasStoreFloat16x8(c, accu00); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasZeroFloat16x8(); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + MlasStoreFloat16x8(c + ldc, accu10); + } + n -= 8, b += 8, c += 8; + } + if (n & 4) { + float16x4_t accu00 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + accu00 = ma_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + MlasStoreFloat16x4(c, accu00); + if constexpr (CountM == 2) { + float16x4_t accu10 = MlasZeroFloat16x4(); + accu10 = ma_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + MlasStoreFloat16x4(c + ldc, accu10); + } + n -= 4, b += 4, c += 4; + } + if (Nr0) { + float16x4_t accu00 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasLoadPartialFloat16x4(b, n); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, n); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, n); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, n); + accu00 = ma_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + MlasStorePartialFloat16x4(c, accu00, n); + if constexpr (CountM == 2) { + float16x4_t accu10 = MlasZeroFloat16x4(); + accu10 = ma_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + MlasStorePartialFloat16x4(c + ldc, accu10, n); + } + } + CountK -= 4, B_data += ldb4, A_data += 4; + } else if (CountK > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(A_data, CountK), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(A_data + lda, CountK); + } + size_t n = CountN; + const auto* b = B_data; + auto* c = C_data; + if (largeN) { + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b10 = MlasZeroFloat16x8(); + float16x8_t b11 = MlasZeroFloat16x8(); + float16x8_t b12 = MlasZeroFloat16x8(); + float16x8_t b13 = MlasZeroFloat16x8(); + float16x8_t b20 = MlasZeroFloat16x8(); + float16x8_t b21 = MlasZeroFloat16x8(); + float16x8_t b22 = MlasZeroFloat16x8(); + float16x8_t b23 = MlasZeroFloat16x8(); + if (Kr1) { + b10 = MlasLoadFloat16x8(b + ldb); + b11 = MlasLoadFloat16x8(b + ldb + 8); + b12 = MlasLoadFloat16x8(b + ldb + 16); + b13 = MlasLoadFloat16x8(b + ldb + 24); + } + if (Kr2) { + b20 = MlasLoadFloat16x8(b + 2 * ldb); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + } + for (; n >= 64; n -= 32, b += 32, c += 32) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 
2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + b00 = MlasLoadFloat16x8(b + 32); + b01 = MlasLoadFloat16x8(b + 40); + b02 = MlasLoadFloat16x8(b + 48); + b03 = MlasLoadFloat16x8(b + 56); + if (Kr1) { + b10 = MlasLoadFloat16x8(b + ldb + 32); + b11 = MlasLoadFloat16x8(b + ldb + 40); + b12 = MlasLoadFloat16x8(b + ldb + 48); + b13 = MlasLoadFloat16x8(b + ldb + 56); + } + if (Kr2) { + b20 = MlasLoadFloat16x8(b + 2 * ldb + 32); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 40); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 48); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 56); + } + } + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + accu00 = vfmaq_lane_f16(accu00, b20, 
a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + n -= 32, b += 32, c += 32; + } + if (n & 16) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + } + n -= 16, b += 16, c += 16; + } + if (n & 8) { + float16x8_t accu00 = MlasZeroFloat16x8(), accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + } + float16x8_t b0 = MlasLoadFloat16x8(b); + accu00 = vfmaq_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + accu00 = vfmaq_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + accu00 = vfmaq_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b2, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + } + n -= 8, b += 8, c += 8; + } + if (n & 4) { + float16x4_t accu00 = MlasZeroFloat16x4(), accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x4(); + } + float16x4_t b0 = MlasLoadFloat16x4(b); + accu00 = vfma_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + accu00 = vfma_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x4_t 
b2 = MlasLoadFloat16x4(b + 2 * ldb); + accu00 = vfma_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b2, a1, 2); + } + } + MlasStoreFloat16x4(c, accu00); + if constexpr (CountM == 2) { + MlasStoreFloat16x4(c + ldc, accu10); + } + n -= 4, b += 4, c += 4; + } + if (Nr0) { + float16x4_t accu00 = MlasZeroFloat16x4(), accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x4(); + } + float16x4_t b0 = MlasLoadPartialFloat16x4(b, n); + accu00 = vfma_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, n); + accu00 = vfma_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, n); + accu00 = vfma_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b2, a1, 2); + } + } + MlasStorePartialFloat16x4(c, accu00, n); + if constexpr (CountM == 2) { + MlasStorePartialFloat16x4(c + ldc, accu10, n); + } + } + + CountK -= CountK, B_data += ldb * CountK, A_data += CountK; + } + } + + for (; CountK >= 4; CountK -= 4, B_data += ldb4, A_data += 4) { + float16x4_t a0 = MlasLoadFloat16x4(A_data), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(A_data + lda); + } + size_t n = CountN; + const auto* b = B_data; + auto* c = C_data; + if (largeN) { + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b12 = MlasLoadFloat16x8(b + ldb + 16); + float16x8_t b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + float16x8_t b32 = MlasLoadFloat16x8(b + 3 * ldb + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b13 = MlasLoadFloat16x8(b + ldb + 24); + float16x8_t b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + float16x8_t b33 = MlasLoadFloat16x8(b + 3 * ldb + 24); + for (; n >= 64; n -= 32, b += 32, c += 32) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu02 = MlasLoadFloat16x8(c + 16); + float16x8_t accu03 = MlasLoadFloat16x8(c + 24); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasLoadFloat16x8(c + ldc); + float16x8_t accu11 = MlasLoadFloat16x8(c + ldc + 8); + float16x8_t accu12 = MlasLoadFloat16x8(c + ldc + 16); + float16x8_t accu13 = MlasLoadFloat16x8(c + ldc + 24); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + 
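For readers tracing the intrinsics, it may help to see what these CountM <= 2 paths compute in scalar form: each element of the one or two C rows accumulates a dot product of an A row with a B column, and the largeN/16/8/4/partial branches only change how many columns are produced per pass. The sketch below is illustrative only; it uses float in place of _mlas_fp16_, a hypothetical name, and models just the accumulate (beta == 1) case.

// Scalar model of the non-transposed-B kernel for at most two rows of A.
// Hypothetical helper, float instead of _mlas_fp16_.
#include <cstddef>

void ReferenceHGemmB(const float* A, const float* B, float* C,
                     size_t CountM,              // 1 or 2
                     size_t CountN, size_t CountK,
                     size_t lda, size_t ldb, size_t ldc)
{
    for (size_t m = 0; m < CountM; ++m) {
        for (size_t n = 0; n < CountN; ++n) {
            float acc = C[m * ldc + n];          // accumulate into existing C
            for (size_t k = 0; k < CountK; ++k) {
                acc += A[m * lda + k] * B[k * ldb + n];
            }
            C[m * ldc + n] = acc;
        }
    }
}

The vectorized code walks K four rows at a time and N in strips of 32, 16, 8 and 4 columns plus a partial tail, which is where the n & 16 / n & 8 / n & 4 branches come from.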
MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + b00 = MlasLoadFloat16x8(b + 32); + b10 = MlasLoadFloat16x8(b + ldb + 32); + b20 = MlasLoadFloat16x8(b + 2 * ldb + 32); + b30 = MlasLoadFloat16x8(b + 3 * ldb + 32); + b01 = MlasLoadFloat16x8(b + 40); + b11 = MlasLoadFloat16x8(b + ldb + 40); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 40); + b31 = MlasLoadFloat16x8(b + 3 * ldb + 40); + b02 = MlasLoadFloat16x8(b + 48); + b12 = MlasLoadFloat16x8(b + ldb + 48); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 48); + b32 = MlasLoadFloat16x8(b + 3 * ldb + 48); + b03 = MlasLoadFloat16x8(b + 56); + b13 = MlasLoadFloat16x8(b + ldb + 56); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 56); + b33 = MlasLoadFloat16x8(b + 3 * ldb + 56); + } + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu02 = MlasLoadFloat16x8(c + 16); + float16x8_t accu03 = MlasLoadFloat16x8(c + 24); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasLoadFloat16x8(c + ldc); + float16x8_t accu11 = MlasLoadFloat16x8(c + ldc + 8); + float16x8_t accu12 = MlasLoadFloat16x8(c + ldc + 16); + float16x8_t accu13 = MlasLoadFloat16x8(c + ldc + 24); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + n -= 32, b += 32, c += 32; + } + if (n & 16) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasLoadFloat16x8(c + ldc); + float16x8_t accu11 = MlasLoadFloat16x8(c + ldc + 8); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + } + n -= 16, b += 16, c += 16; + } + if (n & 8) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + MlasStoreFloat16x8(c, accu00); + if constexpr (CountM == 2) { + float16x8_t accu10 = 
MlasLoadFloat16x8(c + ldc); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + MlasStoreFloat16x8(c + ldc, accu10); + } + n -= 8, b += 8, c += 8; + } + if (n & 4) { + float16x4_t accu00 = MlasLoadFloat16x4(c); + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + accu00 = ma_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + MlasStoreFloat16x4(c, accu00); + if constexpr (CountM == 2) { + float16x4_t accu10 = MlasLoadFloat16x4(c + ldc); + accu10 = ma_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + MlasStoreFloat16x4(c + ldc, accu10); + } + n -= 4, b += 4, c += 4; + } + if (Nr0) { + float16x4_t accu00 = MlasLoadPartialFloat16x4(c, n); + float16x4_t b0 = MlasLoadPartialFloat16x4(b, n); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, n); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, n); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, n); + accu00 = ma_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + MlasStorePartialFloat16x4(c, accu00, n); + if constexpr (CountM == 2) { + float16x4_t accu10 = MlasLoadPartialFloat16x4(c + ldc, n); + accu10 = ma_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + MlasStorePartialFloat16x4(c + ldc, accu10, n); + } + } + } + + if (CountK > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(A_data, CountK), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(A_data + lda, CountK); + } + size_t n = CountN; + const auto* b = B_data; + auto* c = C_data; + if (largeN) { + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b10 = MlasZeroFloat16x8(); + float16x8_t b11 = MlasZeroFloat16x8(); + float16x8_t b12 = MlasZeroFloat16x8(); + float16x8_t b13 = MlasZeroFloat16x8(); + float16x8_t b20 = MlasZeroFloat16x8(); + float16x8_t b21 = MlasZeroFloat16x8(); + float16x8_t b22 = MlasZeroFloat16x8(); + float16x8_t b23 = MlasZeroFloat16x8(); + if (Kr1) { + b10 = MlasLoadFloat16x8(b + ldb); + b11 = MlasLoadFloat16x8(b + ldb + 8); + b12 = MlasLoadFloat16x8(b + ldb + 16); + b13 = MlasLoadFloat16x8(b + ldb + 24); + } + if (Kr2) { + b20 = MlasLoadFloat16x8(b + 2 * ldb); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + } + for (; n >= 64; n -= 32, b += 32, c += 32) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu02 = MlasLoadFloat16x8(c + 16); + float16x8_t accu03 = MlasLoadFloat16x8(c + 24); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x8(c + ldc); + accu11 = MlasLoadFloat16x8(c + ldc + 8); + accu12 = MlasLoadFloat16x8(c + ldc + 16); + accu13 = MlasLoadFloat16x8(c + ldc + 24); + } + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, 
a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + b00 = MlasLoadFloat16x8(b + 32); + b01 = MlasLoadFloat16x8(b + 40); + b02 = MlasLoadFloat16x8(b + 48); + b03 = MlasLoadFloat16x8(b + 56); + if (Kr1) { + b10 = MlasLoadFloat16x8(b + ldb + 32); + b11 = MlasLoadFloat16x8(b + ldb + 40); + b12 = MlasLoadFloat16x8(b + ldb + 48); + b13 = MlasLoadFloat16x8(b + ldb + 56); + } + if (Kr2) { + b20 = MlasLoadFloat16x8(b + 2 * ldb + 32); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 40); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 48); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 56); + } + } + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu02 = MlasLoadFloat16x8(c + 16); + float16x8_t accu03 = MlasLoadFloat16x8(c + 24); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x8(c + ldc); + accu11 = MlasLoadFloat16x8(c + ldc + 8); + accu12 = MlasLoadFloat16x8(c + ldc + 16); + accu13 = MlasLoadFloat16x8(c + ldc + 24); + } + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + 
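The Kr1/Kr2 guards in this remainder path encode whether the K tail (CountK % 4) still has a second and third row of B: lane 0 of the partially loaded A vector always contributes, while the lane-1 and lane-2 multiply-adds are skipped when those B rows do not exist, so no memory past the end of B is touched. A scalar sketch of the same logic, assuming float data and hypothetical names:

// k_remainder is CountK % 4, in 1..3; b points at the first remaining row of B.
#include <cstddef>

void TailKAccumulate(const float* a, const float* b, float* c,
                     size_t n_cols, size_t ldb, size_t k_remainder)
{
    const bool Kr1 = k_remainder > 1;            // second row present
    const bool Kr2 = k_remainder > 2;            // third row present
    for (size_t n = 0; n < n_cols; ++n) {
        float acc = c[n];
        acc += a[0] * b[n];                      // lane 0, always valid
        if (Kr1) acc += a[1] * b[ldb + n];       // lane 1
        if (Kr2) acc += a[2] * b[2 * ldb + n];   // lane 2
        c[n] = acc;
    }
}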
MlasStoreFloat16x8(c + ldc + 24, accu13); + } + n -= 32, b += 32, c += 32; + } + if (n & 16) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu10, accu11; + if constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x8(c + ldc); + accu11 = MlasLoadFloat16x8(c + ldc + 8); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + } + n -= 16, b += 16, c += 16; + } + if (n & 8) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x8(c + ldc); + } + float16x8_t b0 = MlasLoadFloat16x8(b); + accu00 = vfmaq_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + accu00 = vfmaq_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + accu00 = vfmaq_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b2, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + } + n -= 8, b += 8, c += 8; + } + if (n & 4) { + float16x4_t accu00 = MlasLoadFloat16x4(c); + float16x4_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x4(c + ldc); + } + float16x4_t b0 = MlasLoadFloat16x4(b); + accu00 = vfma_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + accu00 = vfma_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + accu00 = vfma_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b2, a1, 2); + } + } + MlasStoreFloat16x4(c, accu00); + if constexpr (CountM == 2) { + MlasStoreFloat16x4(c + ldc, accu10); + } + n -= 4, b += 4, c += 4; + } + if (Nr0) { + float16x4_t accu00 = MlasLoadPartialFloat16x4(c, n); + float16x4_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasLoadPartialFloat16x4(c + ldc, n); + } + float16x4_t b0 = MlasLoadPartialFloat16x4(b, n); + accu00 = vfma_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = 
vfma_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, n); + accu00 = vfma_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, n); + accu00 = vfma_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b2, a1, 2); + } + } + MlasStorePartialFloat16x4(c, accu00, n); + if constexpr (CountM == 2) { + MlasStorePartialFloat16x4(c + ldc, accu10, n); + } + } + + CountK -= CountK, B_data += ldb * CountK, A_data += CountK; + } + } + +void HGemm_B_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_TransposedB_Kernel only support <= 2 rows"); + } + const auto* A_data = reinterpret_cast(A); + const auto* B_data = reinterpret_cast(B); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (alpha == f16_1.val && beta == f16_0.val) { + HGemm_B_Kernel_Simple<1, true>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc); + } else if (alpha == f16_1.val && beta == f16_1.val) { + HGemm_B_Kernel_Simple<1, false>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc); + } else { + HGemm_B_Kernel_Complicated<1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } else { + if (alpha == f16_1.val && beta == f16_0.val) { + HGemm_B_Kernel_Simple<2, true>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc); + } else if (alpha == f16_1.val && beta == f16_1.val) { + HGemm_B_Kernel_Simple<2, false>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc); + } else { + HGemm_B_Kernel_Complicated<2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } +} + +// beta_behavior: 0 -> beta == 0, 1 -> beta == 1, 2 -> beta != 0 && beta != 1 +template +void HGemm_PackedB_Kernel_Impl( + const _mlas_fp16_* A, + const _mlas_fp16_* PackedB, + _mlas_fp16_* C, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + const float16x8_t alpha_v8 = MlasBroadcastFloat16x8(alpha); + const float16x8_t beta_v8 = MlasBroadcastFloat16x8(beta); + const float16x4_t alpha_v4 = MlasBroadcastFloat16x4(alpha); + const float16x4_t beta_v4 = MlasBroadcastFloat16x4(beta); + const bool Kr0 = CountK & 3; + const bool Kr1 = (CountK & 3) > 1; + const bool Kr2 = (CountK & 3) > 2; + const bool largeK = CountK >= 4; + for (; CountN >= 32; CountN -= 32, C += 32) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 64); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 96); + float16x8_t b01 = 
MlasLoadFloat16x8(PackedB + 8); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 72); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 104); + float16x8_t b02 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b12 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b22 = MlasLoadFloat16x8(PackedB + 80); + float16x8_t b32 = MlasLoadFloat16x8(PackedB + 112); + float16x8_t b03 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b13 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t b23 = MlasLoadFloat16x8(PackedB + 88); + float16x8_t b33 = MlasLoadFloat16x8(PackedB + 120); + for (; k >= 8; k -= 4, a += 4, PackedB += 4 * 32) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda + 4); + } + b00 = MlasLoadFloat16x8(PackedB + 128); + b10 = MlasLoadFloat16x8(PackedB + 160); + b20 = MlasLoadFloat16x8(PackedB + 192); + b30 = MlasLoadFloat16x8(PackedB + 224); + b01 = MlasLoadFloat16x8(PackedB + 136); + b11 = MlasLoadFloat16x8(PackedB + 168); + b21 = MlasLoadFloat16x8(PackedB + 200); + b31 = MlasLoadFloat16x8(PackedB + 232); + b02 = MlasLoadFloat16x8(PackedB + 144); + b12 = MlasLoadFloat16x8(PackedB + 176); + b22 = MlasLoadFloat16x8(PackedB + 208); + b32 = MlasLoadFloat16x8(PackedB + 240); + b03 = MlasLoadFloat16x8(PackedB + 152); + b13 = MlasLoadFloat16x8(PackedB + 184); + b23 = MlasLoadFloat16x8(PackedB + 216); + b33 = MlasLoadFloat16x8(PackedB + 248); + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + } + k -= 4, a += 4, PackedB += 4 * 32; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b02 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b03 = MlasLoadFloat16x8(PackedB + 24); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b12 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b13 = MlasLoadFloat16x8(PackedB + 56); + 
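The offsets used on PackedB follow from the packing layout: within a 32-column panel the data is laid out K-major, so panel element (k, j) sits at PackedB[k * 32 + j], the four rows consumed per step sit 32 elements apart, and each group of four K rows advances the pointer by 4 * 32. A minimal sketch of a routine that produces this layout, assuming float data and a hypothetical name (the narrower 16- and 8-column tails use the same scheme with 16 and 8 in place of 32):

// Pack a CountK x 32 panel of B (leading dimension ldb) into K-major order.
#include <cstddef>

void PackB32Panel(const float* B, size_t ldb, size_t CountK, float* PackedB)
{
    for (size_t k = 0; k < CountK; ++k) {
        for (size_t j = 0; j < 32; ++j) {
            PackedB[k * 32 + j] = B[k * ldb + j];
        }
    }
}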
accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 64); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 72); + float16x8_t b22 = MlasLoadFloat16x8(PackedB + 80); + float16x8_t b23 = MlasLoadFloat16x8(PackedB + 88); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + PackedB += k * 32; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + float16x8_t c02 = MlasLoadFloat16x8(C + 16); + float16x8_t c03 = MlasLoadFloat16x8(C + 24); + + MlasStoreFloat16x8(C, vfmaq_f16(c00, accu00, alpha_v8)); + MlasStoreFloat16x8(C + 8, vfmaq_f16(c01, accu01, alpha_v8)); + MlasStoreFloat16x8(C + 16, vfmaq_f16(c02, accu02, alpha_v8)); + MlasStoreFloat16x8(C + 24, vfmaq_f16(c03, accu03, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + float16x8_t c12 = MlasLoadFloat16x8(C + ldc + 16); + float16x8_t c13 = MlasLoadFloat16x8(C + ldc + 24); + MlasStoreFloat16x8(C + ldc, vfmaq_f16(c10, accu10, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 8, vfmaq_f16(c11, accu11, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 16, vfmaq_f16(c12, accu12, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 24, vfmaq_f16(c13, accu13, alpha_v8)); + } + } else if constexpr (beta_behavior == 2) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + float16x8_t c02 = MlasLoadFloat16x8(C + 16); + float16x8_t c03 = MlasLoadFloat16x8(C + 24); + + MlasStoreFloat16x8(C, vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8)); + MlasStoreFloat16x8(C + 8, vfmaq_f16(vmulq_f16(c01, beta_v8), accu01, alpha_v8)); + MlasStoreFloat16x8(C + 16, vfmaq_f16(vmulq_f16(c02, beta_v8), accu02, alpha_v8)); + MlasStoreFloat16x8(C + 24, vfmaq_f16(vmulq_f16(c03, beta_v8), accu03, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + float16x8_t c12 = MlasLoadFloat16x8(C + ldc + 16); + float16x8_t c13 = MlasLoadFloat16x8(C + ldc + 24); + MlasStoreFloat16x8(C + ldc, vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 8, vfmaq_f16(vmulq_f16(c11, beta_v8), accu11, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 16, vfmaq_f16(vmulq_f16(c12, beta_v8), accu12, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 24, vfmaq_f16(vmulq_f16(c13, beta_v8), accu13, alpha_v8)); + } + } else { + MlasStoreFloat16x8(C, vmulq_f16(accu00, alpha_v8)); + MlasStoreFloat16x8(C + 8, vmulq_f16(accu01, alpha_v8)); + MlasStoreFloat16x8(C + 16, vmulq_f16(accu02, alpha_v8)); + MlasStoreFloat16x8(C + 24, vmulq_f16(accu03, alpha_v8)); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(C + ldc, vmulq_f16(accu10, alpha_v8)); + 
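The three beta_behavior specializations are the usual GEMM epilogues, selected at compile time so the beta == 0 path never reads C (which may be uninitialized) and the beta == 1 path saves a multiply. In scalar terms, per element:

// Scalar equivalent of the beta_behavior dispatch (0: beta == 0, 1: beta == 1, 2: general).
template <int beta_behavior>
inline float Epilogue(float acc, float c, float alpha, float beta)
{
    if constexpr (beta_behavior == 1) {
        return acc * alpha + c;             // vfmaq_f16(c, accu, alpha_v8)
    } else if constexpr (beta_behavior == 2) {
        return acc * alpha + c * beta;      // vfmaq_f16(vmulq_f16(c, beta_v8), accu, alpha_v8)
    } else {
        return acc * alpha;                 // vmulq_f16(accu, alpha_v8)
    }
}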
MlasStoreFloat16x8(C + ldc + 8, vmulq_f16(accu11, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 16, vmulq_f16(accu12, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 24, vmulq_f16(accu13, alpha_v8)); + } + } + } + + if (CountN & 16) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(), accu10, accu11; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 56); + for (; k >= 8; k -= 4, a += 4, PackedB += 4 * 16) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda + 4); + } + b00 = MlasLoadFloat16x8(PackedB + 64); + b01 = MlasLoadFloat16x8(PackedB + 72); + b10 = MlasLoadFloat16x8(PackedB + 80); + b11 = MlasLoadFloat16x8(PackedB + 88); + b20 = MlasLoadFloat16x8(PackedB + 96); + b21 = MlasLoadFloat16x8(PackedB + 104); + b30 = MlasLoadFloat16x8(PackedB + 112); + b31 = MlasLoadFloat16x8(PackedB + 120); + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + } + k -= 4, a += 4, PackedB += 4 * 16; + } + + if (Kr0) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + } + PackedB += k * 16; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + accu00 = vfmaq_f16(c00, accu00, alpha_v8); + accu01 = vfmaq_f16(c01, accu01, alpha_v8); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 
8, accu01); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + accu10 = vfmaq_f16(c10, accu10, alpha_v8); + accu11 = vfmaq_f16(c11, accu11, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else if constexpr (beta_behavior == 2) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + accu00 = vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8); + accu01 = vfmaq_f16(vmulq_f16(c01, beta_v8), accu01, alpha_v8); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + accu10 = vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8); + accu11 = vfmaq_f16(vmulq_f16(c11, beta_v8), accu11, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else { + accu00 = vmulq_f16(accu00, alpha_v8); + accu01 = vmulq_f16(accu01, alpha_v8); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + if constexpr (CountM == 2) { + accu10 = vmulq_f16(accu10, alpha_v8); + accu11 = vmulq_f16(accu11, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } + + CountN -= 16, C += 16; + } + + if (CountN & 8) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + for (; k >= 8; k -= 4, a += 4, PackedB += 4 * 8) { + accu00 = maq_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda + 4); + } + b0 = MlasLoadFloat16x8(PackedB + 32); + b1 = MlasLoadFloat16x8(PackedB + 40); + b2 = MlasLoadFloat16x8(PackedB + 48); + b3 = MlasLoadFloat16x8(PackedB + 56); + } + accu00 = maq_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + } + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu00 = vfmaq_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu00 = vfmaq_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu00 = vfmaq_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b2, a1, 2); + } + } + PackedB += k * 8; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C); + accu00 = vfmaq_f16(c0, accu00, alpha_v8); + MlasStoreFloat16x8(C, accu00); + if constexpr (CountM == 2) { + float16x8_t c1 = MlasLoadFloat16x8(C + ldc); + accu10 = vfmaq_f16(c1, accu10, 
alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + } + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C); + accu00 = vfmaq_f16(vmulq_f16(c0, beta_v8), accu00, alpha_v8); + MlasStoreFloat16x8(C, accu00); + if constexpr (CountM == 2) { + float16x8_t c1 = MlasLoadFloat16x8(C + ldc); + accu10 = vfmaq_f16(vmulq_f16(c1, beta_v8), accu10, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + } + } else { + accu00 = vmulq_f16(accu00, alpha_v8); + MlasStoreFloat16x8(C, accu00); + if constexpr (CountM == 2) { + accu10 = vmulq_f16(accu10, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + } + } + + CountN -= 8, C += 8; + } + + if (CountN > 0) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(), accu1; + if constexpr (CountM == 2) { + accu1 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + for (; k >= 8; k -= 4, a += 4, PackedB += 4 * 8) { + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + accu1 = maq_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda + 4); + } + b0 = MlasLoadFloat16x8(PackedB + 32); + b1 = MlasLoadFloat16x8(PackedB + 40); + b2 = MlasLoadFloat16x8(PackedB + 48); + b3 = MlasLoadFloat16x8(PackedB + 56); + } + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + accu1 = maq_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + } + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu0 = vfmaq_lane_f16(accu0, b0, a0, 0); + if constexpr (CountM == 2) { + accu1 = vfmaq_lane_f16(accu1, b0, a1, 0); + } + if (Kr1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b1, a0, 1); + if constexpr (CountM == 2) { + accu1 = vfmaq_lane_f16(accu1, b1, a1, 1); + } + } + if (Kr2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu0 = vfmaq_lane_f16(accu0, b2, a0, 2); + if constexpr (CountM == 2) { + accu1 = vfmaq_lane_f16(accu1, b2, a1, 2); + } + } + PackedB += k * 8; + } + + float16x4_t accu0_low = vget_low_f16(accu0); + float16x4_t accu0_high = vget_high_f16(accu0); + float16x4_t accu1_low, accu1_high; + if constexpr (CountM == 2) { + accu1_low = vget_low_f16(accu1); + accu1_high = vget_high_f16(accu1); + } + + if (CountN & 4) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadFloat16x4(C); + MlasStoreFloat16x4(C, vfma_f16(c0, accu0_low, alpha_v4)); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadFloat16x4(C + ldc); + MlasStoreFloat16x4(C + ldc, vfma_f16(c1, accu1_low, alpha_v4)); + } + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadFloat16x4(C); + MlasStoreFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v4), accu0_low, alpha_v4)); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadFloat16x4(C + ldc); + MlasStoreFloat16x4(C + ldc, vfma_f16(vmul_f16(c1, beta_v4), accu1_low, alpha_v4)); + } + } else { + MlasStoreFloat16x4(C, vmul_f16(accu0_low, alpha_v4)); + if constexpr (CountM == 2) { + MlasStoreFloat16x4(C 
+ ldc, vmul_f16(accu1_low, alpha_v4)); + } + } + CountN -= 4, C += 4; + accu0_low = accu0_high; + if constexpr (CountM == 2) { + accu1_low = accu1_high; + } + } + + if (CountN) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + MlasStorePartialFloat16x4(C, vfma_f16(c0, accu0_low, alpha_v4), CountN); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadPartialFloat16x4(C + ldc, CountN); + MlasStorePartialFloat16x4(C + ldc, vfma_f16(c1, accu1_low, alpha_v4), CountN); + } + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + MlasStorePartialFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v4), accu0_low, alpha_v4), CountN); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadPartialFloat16x4(C + ldc, CountN); + MlasStorePartialFloat16x4(C + ldc, vfma_f16(vmul_f16(c1, beta_v4), accu1_low, alpha_v4), CountN); + } + } else { + MlasStorePartialFloat16x4(C, vmul_f16(accu0_low, alpha_v4), CountN); + if constexpr (CountM == 2) { + MlasStorePartialFloat16x4(C + ldc, vmul_f16(accu1_low, alpha_v4), CountN); + } + } + } + } +} + +void HGemm_PackedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_PackedB_Kernel only support <= 2 rows"); + } + + const auto* A_data = reinterpret_cast(A); + const auto* PackedB_data = reinterpret_cast(PackedB); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (beta == f16_0.val) { + HGemm_PackedB_Kernel_Impl<0, 1>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_PackedB_Kernel_Impl<1, 1>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else { + HGemm_PackedB_Kernel_Impl<2, 1>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } + } else { + if (beta == f16_0.val) { + HGemm_PackedB_Kernel_Impl<0, 2>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_PackedB_Kernel_Impl<1, 2>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else { + HGemm_PackedB_Kernel_Impl<2, 2>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } + } +} + +} // namespace hgemm_neon diff --git a/src/lib/hgemm_kernel_neon.cpp b/src/lib/hgemm_kernel_neon.cpp new file mode 100644 index 0000000..1531ce9 --- /dev/null +++ b/src/lib/hgemm_kernel_neon.cpp @@ -0,0 +1,30 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hgemm_kernel_neon.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. 
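The dispatch object defined in this file is built with an immediately invoked lambda so the function-pointer table is filled once at static initialization, and only when the target has the required NEON fp16 support. A minimal sketch of the same idiom with a hypothetical, much smaller dispatch struct (the real MLAS_HGEMM_DISPATCH carries the pack and GEMM kernel pointers listed below):

#include <cstddef>

struct ExampleHGemmDispatch {
    void (*GemmKernel)(const float* A, const float* B, float* C, size_t N) = nullptr;
};

// Stub standing in for the real NEON kernel; illustration only.
inline void ExampleNeonKernel(const float*, const float*, float*, size_t) {}

const ExampleHGemmDispatch ExampleDispatchNeon = []() {
    ExampleHGemmDispatch d;
#if defined(__aarch64__)
    d.GemmKernel = ExampleNeonKernel;   // wired up only when the target supports it
#endif
    return d;
}();

With this pattern a caller can check the pointer for null and fall back to a portable path when no specialized kernel was registered.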
+ +--*/ + +#include "mlasi.h" +#include "halfgemm.h" + +const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon = [](){ + MLAS_HGEMM_DISPATCH d; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HPackBKernel_TransposedB = hgemm_neon::HPackB_TransposedB_Kernel; + d.HPackBKernel_B = hgemm_neon::HPackB_B_Kernel; + d.HGemmKernel_TransposedB = hgemm_neon::HGemm_TransposedB_Kernel; + d.HGemmKernel_B = hgemm_neon::HGemm_B_Kernel; + d.HGemmKernel_PackedB = hgemm_neon::HGemm_PackedB_Kernel; +#endif + return d; +}(); diff --git a/src/lib/hqnbitgemm_kernel_neon_fp16.cpp b/src/lib/hqnbitgemm_kernel_neon_fp16.cpp new file mode 100644 index 0000000..5b1f9d7 --- /dev/null +++ b/src/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -0,0 +1,865 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hqnbitgemm_kernel_neon_fp16.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_QNBIT_GEMM_COMPUTE_TYPE HQNBIT_CompFp16. + +--*/ + +#include + +#include +#include +#include + +#include "fp16_common.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ +MLAS_FORCEINLINE void +Transpose8x8(uint8x8_t& v0, uint8x8_t& v1, uint8x8_t& v2, uint8x8_t& v3, + uint8x8_t& v4, uint8x8_t& v5, uint8x8_t& v6, uint8x8_t& v7) +{ + // v0: | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // v1: | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // v2: | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // v3: | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // v4: | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // v5: | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // v6: | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // v7: | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + + uint8x8x2_t a0 = vtrn_u8(v0, v1); + uint8x8x2_t a1 = vtrn_u8(v2, v3); + uint8x8x2_t a2 = vtrn_u8(v4, v5); + uint8x8x2_t a3 = vtrn_u8(v6, v7); + + // a0[0]: | B00 B10 | B01 B11 | B40 B50 | B41 B51 | B80 B90 | B81 B91 | Bc0 Bd0 | Bc1 Bd1 | + // a0[1]: | B20 B30 | B21 B31 | B60 B70 | B61 B71 | Ba0 Bb0 | Ba1 Bb1 | Be0 Bf0 | Be1 Bf1 | + // a1[0]: | B02 B12 | B03 B13 | B42 B52 | B43 B53 | B82 B92 | B83 B93 | Bc2 Bd2 | Bc3 Bd3 | + // a1[1]: | B22 B32 | B23 B33 | B62 B72 | B63 B73 | Ba2 Bb2 | Ba3 Bb3 | Be2 Bf2 | Be3 Bf3 | + // a2[0]: | B04 B14 | B05 B15 | B44 B54 | B45 B55 | B84 B94 | B85 B95 | Bc4 Bd4 | Bc5 Bd5 | + // a2[1]: | B24 B34 | B25 B35 | B64 B74 | B65 B75 | Ba4 Bb4 | Ba5 Bb5 | Be4 Bf4 | Be5 Bf5 | + // a3[0]: | B06 B16 | B07 B17 | B46 B56 | B47 B57 | B86 B96 | B87 B97 | Bc6 Bd6 | Bc7 Bd7 | + // a3[1]: | B26 B36 | B27 B37 | B66 B76 | B67 B77 | Ba6 Bb6 | Ba7 Bb7 | Be6 Bf6 | Be7 Bf7 | + + uint16x4x2_t b0 = vtrn_u16(vreinterpret_u16_u8(a0.val[0]), vreinterpret_u16_u8(a1.val[0])); + uint16x4x2_t b1 = vtrn_u16(vreinterpret_u16_u8(a0.val[1]), vreinterpret_u16_u8(a1.val[1])); + uint16x4x2_t b2 = vtrn_u16(vreinterpret_u16_u8(a2.val[0]), vreinterpret_u16_u8(a3.val[0])); + uint16x4x2_t b3 = vtrn_u16(vreinterpret_u16_u8(a2.val[1]), vreinterpret_u16_u8(a3.val[1])); + + // b0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B80 B90 | B81 B91 | B82 B92 | B83 B93 | + // b0[1]: | B40 B50 | B41 B51 
| B42 B52 | B43 B53 | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | + // b1[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | + // b1[1]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | + // b2[0]: | B04 B14 | B05 B15 | B06 B16 | B07 B17 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // b2[1]: | B44 B54 | B45 B55 | B46 B56 | B47 B57 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // b3[0]: | B24 B34 | B25 B35 | B26 B36 | B27 B37 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // b3[1]: | B64 B74 | B65 B75 | B66 B76 | B67 B77 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), vreinterpret_u32_u16(b2.val[0])); + uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]), vreinterpret_u32_u16(b2.val[1])); + uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b1.val[0]), vreinterpret_u32_u16(b3.val[0])); + uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b1.val[1]), vreinterpret_u32_u16(b3.val[1])); + + // c0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // c0[1]: | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // c1[0]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // c1[1]: | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // c2[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // c2[1]: | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // c3[0]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // c3[1]: | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + v0 = vreinterpret_u8_u32(c0.val[0]); + v1 = vreinterpret_u8_u32(c2.val[0]); + v2 = vreinterpret_u8_u32(c1.val[0]); + v3 = vreinterpret_u8_u32(c3.val[0]); + v4 = vreinterpret_u8_u32(c0.val[1]); + v5 = vreinterpret_u8_u32(c2.val[1]); + v6 = vreinterpret_u8_u32(c1.val[1]); + v7 = vreinterpret_u8_u32(c3.val[1]); +} + +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); + constexpr size_t nbits = 4; + constexpr size_t k_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % k_blk_dim == 0); + + const size_t k_blk_num = MlasDivRoundup(K, k_blk_dim); + const size_t n_blk_num = MlasDivRoundup(N, n_blk_dim); + constexpr size_t k_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, k_blk_dim); + const size_t iterations = k_blk_num * n_blk_num; // one iteration per block + const size_t ld = MlasDivRoundup(K, BlkLen) * MlasQNBitBlkDataSizeInBytes(nbits, BlkLen); + + // + // For blocks 16_K * 8_N, transpose bytes in 8x8 blocks like this: + // src B_k_n: + // | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 
Bb6 | Bc6 Bd6 | Be6 Bf6 | + // | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + // => dst: + // | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + // + + // + // For blocks < 8_N: + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + MlasTrySimpleParallel( + ThreadPool, iterations, + [&](ptrdiff_t tid) { + const size_t n_blk = tid / k_blk_num; + const size_t k_blk = tid % k_blk_num; + size_t n = n_blk * n_blk_dim; + const size_t src_offset = n * ld + k_blk * k_blk_bytes; + + if (n + n_blk_dim <= N) { + const size_t dst_offset = n * ld + k_blk * k_blk_bytes * n_blk_dim; + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + dst_offset; + + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v1 = vld1_u8(src + ld); + uint8x8_t v2 = vld1_u8(src + 2*ld); + uint8x8_t v3 = vld1_u8(src + 3*ld); + uint8x8_t v4 = vld1_u8(src + 4*ld); + uint8x8_t v5 = vld1_u8(src + 5*ld); + uint8x8_t v6 = vld1_u8(src + 6*ld); + uint8x8_t v7 = vld1_u8(src + 7*ld); + + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + vst1_u8(dst, v0); + vst1_u8(dst + 8, v1); + vst1_u8(dst + 16, v2); + vst1_u8(dst + 24, v3); + vst1_u8(dst + 32, v4); + vst1_u8(dst + 40, v5); + vst1_u8(dst + 48, v6); + vst1_u8(dst + 56, v7); + } else { + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + src_offset; + + for (; n < N; ++n, src += ld, dst += ld) { + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v_even = vand_u8(v0, vdup_n_u8(0x0F)); + uint8x8_t v_odd = vshr_n_u8(v0, 4); + uint8x8x2_t v1 = vzip_u8(v_even, v_odd); + uint8x8_t v2 = vorr_u8(v1.val[0], vshl_n_u8(v1.val[1], 4)); + vst1_u8(dst, v2); + } + } + } + ); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t b01 = vld1_u8(src_ptr); + uint8x8_t b23 = vld1_u8(src_ptr + 8); + uint8x8_t b45 = vld1_u8(src_ptr + 16); + uint8x8_t b67 = vld1_u8(src_ptr + 24); + uint8x8_t b89 = vld1_u8(src_ptr + 32); + uint8x8_t bab = vld1_u8(src_ptr + 40); + uint8x8_t bcd = vld1_u8(src_ptr + 48); + uint8x8_t bef = vld1_u8(src_ptr + 56); + + float16x8_t b0 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b01, low_mask), 0)); + float16x8_t b1 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b01, 4), 0)); + float16x8_t b2 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b23, low_mask), 0)); + float16x8_t b3 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b23, 4), 0)); + float16x8_t b4 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b45, low_mask), 0)); + float16x8_t b5 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b45, 4), 0)); + 
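The nibble handling implements the standard 4-bit dequantization value = scale * (q - zero_point): the low nibble of each byte is isolated with a mask, the high nibble with a shift, and the subtraction is folded into the precomputed neg_scaled_zp = -scale * zero_point (or -scale * 8 when no zero points are supplied), so a single fused multiply-add per element remains. Scalar sketch, using float in place of fp16:

#include <cstdint>

// Dequantize the two 4-bit values packed in one byte.
inline void DequantizeByte(std::uint8_t packed, float scale, float neg_scaled_zp,
                           float& lo, float& hi)
{
    const std::uint8_t q_lo = packed & 0x0F;                  // vand_u8(b, low_mask)
    const std::uint8_t q_hi = packed >> 4;                    // vshr_n_u8(b, 4)
    lo = static_cast<float>(q_lo) * scale + neg_scaled_zp;    // vfmaq_f16(neg_scaled_zp, b, scale)
    hi = static_cast<float>(q_hi) * scale + neg_scaled_zp;
}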
float16x8_t b6 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b67, low_mask), 0)); + float16x8_t b7 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b67, 4), 0)); + float16x8_t b8 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b89, low_mask), 0)); + float16x8_t b9 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b89, 4), 0)); + float16x8_t ba = vcvtq_f16_u16(vshll_n_u8(vand_u8(bab, low_mask), 0)); + float16x8_t bb = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bab, 4), 0)); + float16x8_t bc = vcvtq_f16_u16(vshll_n_u8(vand_u8(bcd, low_mask), 0)); + float16x8_t bd = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bcd, 4), 0)); + float16x8_t be = vcvtq_f16_u16(vshll_n_u8(vand_u8(bef, low_mask), 0)); + float16x8_t bf = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bef, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, b0, scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, b1, scale); + float16x8_t c2 = vfmaq_f16(neg_scaled_zp, b2, scale); + float16x8_t c3 = vfmaq_f16(neg_scaled_zp, b3, scale); + float16x8_t c4 = vfmaq_f16(neg_scaled_zp, b4, scale); + float16x8_t c5 = vfmaq_f16(neg_scaled_zp, b5, scale); + float16x8_t c6 = vfmaq_f16(neg_scaled_zp, b6, scale); + float16x8_t c7 = vfmaq_f16(neg_scaled_zp, b7, scale); + float16x8_t c8 = vfmaq_f16(neg_scaled_zp, b8, scale); + float16x8_t c9 = vfmaq_f16(neg_scaled_zp, b9, scale); + float16x8_t ca = vfmaq_f16(neg_scaled_zp, ba, scale); + float16x8_t cb = vfmaq_f16(neg_scaled_zp, bb, scale); + float16x8_t cc = vfmaq_f16(neg_scaled_zp, bc, scale); + float16x8_t cd = vfmaq_f16(neg_scaled_zp, bd, scale); + float16x8_t ce = vfmaq_f16(neg_scaled_zp, be, scale); + float16x8_t cf = vfmaq_f16(neg_scaled_zp, bf, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); + MlasStoreFloat16x8(dst_ptr + 16, c2); + MlasStoreFloat16x8(dst_ptr + 24, c3); + MlasStoreFloat16x8(dst_ptr + 32, c4); + MlasStoreFloat16x8(dst_ptr + 40, c5); + MlasStoreFloat16x8(dst_ptr + 48, c6); + MlasStoreFloat16x8(dst_ptr + 56, c7); + MlasStoreFloat16x8(dst_ptr + 64, c8); + MlasStoreFloat16x8(dst_ptr + 72, c9); + MlasStoreFloat16x8(dst_ptr + 80, ca); + MlasStoreFloat16x8(dst_ptr + 88, cb); + MlasStoreFloat16x8(dst_ptr + 96, cc); + MlasStoreFloat16x8(dst_ptr + 104, cd); + MlasStoreFloat16x8(dst_ptr + 112, ce); + MlasStoreFloat16x8(dst_ptr + 120, cf); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 1 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t v0 = vld1_u8(src_ptr); + + float16x8_t f_low = vcvtq_f16_u16(vshll_n_u8(vand_u8(v0, low_mask), 0)); + float16x8_t f_high = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(v0, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, f_low, scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, f_high, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); +} + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +) { + MLAS_UNREFERENCED_PARAMETER(K); + constexpr size_t nbits = 4; + constexpr size_t kk_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % kk_blk_dim == 0); + + const size_t kk_blk_num = BlockCountK * BlkLen / kk_blk_dim; + constexpr size_t kk_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, kk_blk_dim); + const size_t kk_n_src_bytes = kk_blk_bytes * n_blk_dim; + const size_t 
kk_n_dst_size = kk_blk_dim * n_blk_dim; + const size_t ld_blk_src = kk_blk_num * kk_n_src_bytes; + const size_t ld_blk_dst = BlkLen * BlockCountK * n_blk_dim; + const size_t ld_blk_scale = BlockCountK * n_blk_dim; + const size_t ld_zp = (BlockCountK + 1) / 2; + const size_t ld_blk_zp = ld_zp * n_blk_dim; + const float16x8_t zp_mid_point_vec = MlasBroadcastFloat16x8(MLAS_FP16(8.0f).val); + const bool has_zp = QuantBZeroPoint != nullptr; + + size_t n = 0; + for (; n + n_blk_dim <= CountN; n += n_blk_dim) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + const std::uint8_t* src_ptr = reinterpret_cast(QuantBData); + auto* dst_ptr = reinterpret_cast<_mlas_fp16_*>(FpData); + + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + // prepare scales and zero_points for the block + _mlas_fp16_ scales[n_blk_dim]; + uint16_t zero_points[n_blk_dim]; + float16x8_t scale_vec; + float16x8_t neg_scaled_zp_vec; + + UnrolledLoop([&](int nn){ + scales[nn] = scales_ptr[nn * BlockCountK]; + }); + scale_vec = MlasLoadFloat16x8(scales); + + if (has_zp) { + UnrolledLoop([&](int nn){ + uint8_t zp = zero_points_ptr[nn * ld_zp]; + zp = (k_blk_i & 1) ? (zp >> 4) : (zp & 0x0F); + zero_points[nn] = static_cast(zp); + }); + uint16x8_t zp_u16_vec = vld1q_u16(zero_points); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<8, 16>(src_ptr, scale_vec, neg_scaled_zp_vec, dst_ptr); + + src_ptr += kk_n_src_bytes; + dst_ptr += kk_n_dst_size; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBData += ld_blk_src; + FpData += ld_blk_dst; + QuantBScale += ld_blk_scale; + QuantBZeroPoint = has_zp ? QuantBZeroPoint + ld_blk_zp : nullptr; + } + + // remaining N + for (; n < CountN; ++n) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + const auto scale = scales_ptr[0]; + float16x8_t scale_vec = MlasBroadcastFloat16x8(scale); + float16x8_t neg_scaled_zp_vec; + + if (has_zp) { + uint8_t zero_point = static_cast(zero_points_ptr[0]); + zero_point = (k_blk_i & 1) ? 
(zero_point >> 4) : (zero_point & 0x0F); + uint16x8_t zp_u16_vec = vdupq_n_u16(static_cast(zero_point)); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<1, 16>( + reinterpret_cast(QuantBData), scale_vec, neg_scaled_zp_vec, + reinterpret_cast<_mlas_fp16_*>(FpData) + ); + + QuantBData += kk_blk_bytes; + FpData += kk_blk_dim; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBScale += BlockCountK; + if (has_zp) { + QuantBZeroPoint += ld_zp; + } + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8), float16x8_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x8(Bias); + } else { + return MlasZeroFloat16x8(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 4), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x4(Bias); + } else { + return MlasZeroFloat16x4(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N == 2 || N == 1)), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + float16x4_t v = MlasZeroFloat16x4(); + + if (Bias) { + v = MlasLoadLaneFloat16x4<0>(Bias, v); + if constexpr (N == 2) { + v = MlasLoadLaneFloat16x4<1>(Bias + 1, v); + } + return v; + } else { + return v; + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 8), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x8_t a0 = MlasLoadFloat16x8(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + float16x8_t b4 = MlasLoadFloat16x8(B + 32); + float16x8_t b5 = MlasLoadFloat16x8(B + 40); + float16x8_t b6 = MlasLoadFloat16x8(B + 48); + float16x8_t b7 = MlasLoadFloat16x8(B + 56); + + // This version uses less instructions, but introduces dependency path between instructions. + // Must pair it with loop unrolling to alleviate dependency path penalty. 
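+    // A sketch of the alternative (not used here): splitting the sum into two independent accumulators,
+    //     float16x8_t c_even = vfmaq_laneq_f16(accumulator, b0, a0, 0);
+    //     float16x8_t c_odd  = vfmaq_laneq_f16(MlasZeroFloat16x8(), b1, a0, 1);
+    //     ... alternate the remaining lanes, then vaddq_f16(c_even, c_odd) ...
+    // roughly halves the serial FMA chain at the cost of an extra register and a final add.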
+ float16x8_t c0 = vfmaq_laneq_f16(accumulator, b0, a0, 0); + c0 = vfmaq_laneq_f16(c0, b1, a0, 1); + c0 = vfmaq_laneq_f16(c0, b2, a0, 2); + c0 = vfmaq_laneq_f16(c0, b3, a0, 3); + c0 = vfmaq_laneq_f16(c0, b4, a0, 4); + c0 = vfmaq_laneq_f16(c0, b5, a0, 5); + c0 = vfmaq_laneq_f16(c0, b6, a0, 6); + c0 = vfmaq_laneq_f16(c0, b7, a0, 7); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 4), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0); + c0 = vfmaq_lane_f16(c0, b1, a0, 1); + c0 = vfmaq_lane_f16(c0, b2, a0, 2); + c0 = vfmaq_lane_f16(c0, b3, a0, 3); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && (K == 2 || K == 1)), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasZeroFloat16x4(); + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + if constexpr (K == 2) a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + float16x8_t b0 = MlasLoadFloat16x8(B), b1; + if constexpr (K == 2) b1 = MlasLoadFloat16x8(B + 8); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0), c01; + if constexpr (K == 2) c01 = vfmaq_lane_f16(c0, b1, a0, 1); + + if constexpr (K == 1) + return c0; + else + return c01; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && K == 8), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x8_t a0 = MlasLoadFloat16x8(A); + + float16x8_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x8(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x8(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x8(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x8(B + ldb * 3); + + float16x8_t c00, c01, c02, c03; + c00 = vmulq_f16(b0, a0); + if constexpr (N > 1) + c01 = vmulq_f16(b1, a0); + else + c01 = MlasZeroFloat16x8(); + if constexpr (N > 2) + c02 = vmulq_f16(b2, a0); + else + c02 = MlasZeroFloat16x8(); + if constexpr (N > 3) + c03 = vmulq_f16(b3, a0); + else + c03 = MlasZeroFloat16x8(); + + Transpose4x8(c00, c01, c02, c03); + + float16x8_t c_low_high = vaddq_f16(vaddq_f16(c00, c01), vaddq_f16(c02, c03)); + float16x4_t c_low = vget_low_f16(c_low_high); + float16x4_t c_high = vget_high_f16(c_low_high); + float16x4_t c = vadd_f16(c_low, c_high); + + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K == 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x4_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x4(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x4(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x4(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x4(B + ldb * 3); + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = 
vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K > 0 && K < 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasZeroFloat16x4(), b1, b2, b3; + if constexpr (N > 1) b1 = MlasZeroFloat16x4(); + if constexpr (N > 2) b2 = MlasZeroFloat16x4(); + if constexpr (N > 3) b3 = MlasZeroFloat16x4(); + + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + b0 = MlasLoadLaneFloat16x4<0>(B, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<0>(B + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<0>(B + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<0>(B + ldb * 3, b3); + + if constexpr (K >= 2) { + a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + b0 = MlasLoadLaneFloat16x4<1>(B + 1, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 3, b3); + } + + if constexpr (K >= 3) { + a0 = MlasLoadLaneFloat16x4<2>(A + 2, a0); + b0 = MlasLoadLaneFloat16x4<2>(B + 2, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 3, b3); + } + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +typename std::enable_if_t<((CountN >= 1 && CountN <= 16 && ((CountN - 1) & CountN) == 0) && (CountM == 1 || CountM == 2)), void> +HQ4BitGemmKernel_CompFp16_Kernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const _mlas_fp16_* Bias, + _mlas_fp16_* C, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + using RegisterType = typename std::conditional_t<(CountN < 8), float16x4_t, float16x8_t>; + + RegisterType accu00, accu01, accu10, accu11; + constexpr size_t b_step = CountN >= 8 ? 8 : 1; + constexpr size_t N = CountN == 16 ? 8 : CountN; + + if constexpr (CountM == 2) { + accu00 = accu10 = PrepareAccumulator(Bias); + } else { + accu00 = PrepareAccumulator(Bias); + } + if constexpr (CountN == 16) { + if constexpr (CountM == 2) { + accu01 = accu11 = PrepareAccumulator(Bias ? Bias + 8 : nullptr); + } else { + accu01 = PrepareAccumulator(Bias ? 
Bias + 8 : nullptr); + } + } + + size_t k = 0; + for (; k + 8 <= K; k += 8, A += 8, B += b_step * 8) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if (K & 4) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 4, A += 4, B += b_step * 4; + } + + if (K & 2) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 2, A += 2, B += b_step * 2; + } + + if (k < K) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C, accu00); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + 8, accu01); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C, accu00); + } else { + MlasStoreLaneFloat16x4<0>(C, accu00); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + 1, accu00); + } + } + + if constexpr (CountM == 2) { + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C + ldc, accu10); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C + ldc, accu10); + } else { + MlasStoreLaneFloat16x4<0>(C + ldc, accu10); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + ldc + 1, accu10); + } + } + } +} + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + assert(CountM <= 2); + + // 2M_16N is the balance between loop unrolling and register spill. + // More unroll will trigger register spill. + // Less unroll will increase micro kernel dependency path penalty. + // TODO: dequant 16N as continuous segments. Current version dequants 8N. 
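+    // Rough register picture for the 2M x 16N tile: the 4 fp16 accumulators (accu00/01/10/11) stay
+    // live across the K loop, and each 8N micro kernel call brings in 8 B vectors plus the A lanes,
+    // which still fits comfortably in the 32 NEON vector registers; widening the tile further is
+    // where spills begin.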
+ const auto* a = reinterpret_cast(A); + const auto* b = reinterpret_cast(B); + const auto* bias = reinterpret_cast(Bias); + auto* c = reinterpret_cast<_mlas_fp16_*>(C); + + for (; CountN >= 16; CountN -= 16) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<16, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<16, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 16 * ldb, c += 16; + if (bias) bias += 16; + } + + if (CountN & 8) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<8, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<8, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 8 * ldb, c += 8; + if (bias) bias += 8; + } + + if (CountN & 4) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<4, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<4, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 4 * ldb, c += 4; + if (bias) bias += 4; + } + + if (CountN & 2) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<2, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<2, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 2 * ldb, c += 2; + if (bias) bias += 2; + } + + if (CountN & 1) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<1, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<1, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + } +} +} // namespace sqnbitgemm_neon diff --git a/src/lib/intrinsics/avx2/saturation_check_avx2.cpp b/src/lib/intrinsics/avx2/saturation_check_avx2.cpp new file mode 100644 index 0000000..5ff4c0e --- /dev/null +++ b/src/lib/intrinsics/avx2/saturation_check_avx2.cpp @@ -0,0 +1,62 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + saturation_check_avx2.cpp + +Abstract: + + This module implements logic to check saturation of the VPMADDUBSW + instruction. + +--*/ + +#include + +#include +#include + +namespace onnxruntime +{ +extern std::atomic saturation_count; +} + +extern "C" void +CheckSaturationForVPMADDUBSW(const __m256i* unsigned_ptr, const __m256i* signed_ptr) +{ + // Load data from memory (unaligned load) + __m256i unsigned_data = _mm256_loadu_si256(unsigned_ptr); + __m256i signed_data = _mm256_loadu_si256(signed_ptr); + + alignas(32) uint8_t unsigned_bytes[32]; // Unsigned input values + alignas(32) int8_t signed_bytes[32]; // Signed input values + + // Store the data into the byte arrays + _mm256_store_si256(reinterpret_cast<__m256i*>(unsigned_bytes), unsigned_data); + _mm256_store_si256(reinterpret_cast<__m256i*>(signed_bytes), signed_data); + + bool saturation_detected = false; + + // Iterate through the 16 pairs of 8-bit unsigned and signed values + for (int i = 0; i < 16; ++i) { + // Perform the VPMADDUBSW operation in higher precision (int32_t) + int32_t computed_value = + static_cast(signed_bytes[2 * i]) * static_cast(static_cast(unsigned_bytes[2 * i])) + + static_cast(signed_bytes[2 * i + 1]) * static_cast(static_cast(unsigned_bytes[2 * i + 1])); + + // If the computed value exceeds the 16-bit signed integer range, saturation occurred + if (computed_value > INT16_MAX || computed_value < INT16_MIN) { + saturation_detected = true; + break; + } + } + + // If saturation is detected, log a warning (only log once based on the atomic count) + if (saturation_detected && ++onnxruntime::saturation_count < 2) { + std::cerr << "Warning: saturation detected in VPMADDUBSW instruction." 
<< std::endl; + } +} diff --git a/src/lib/kai_ukernel_interface.cpp b/src/lib/kai_ukernel_interface.cpp new file mode 100644 index 0000000..fdada83 --- /dev/null +++ b/src/lib/kai_ukernel_interface.cpp @@ -0,0 +1,81 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include "kai_ukernel_interface.h" +#include "mlasi.h" + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod = + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod}; + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod = + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod}; + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod = + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + 
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod}; + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm = + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm}; + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() { + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm; + } else { + return kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod; + } +} + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() { + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + return kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod; + } else { + return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod; + } +} diff --git a/src/lib/kai_ukernel_interface.h b/src/lib/kai_ukernel_interface.h new file mode 100644 index 0000000..1a6f111 --- /dev/null +++ b/src/lib/kai_ukernel_interface.h @@ -0,0 +1,12 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel(); +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel(); diff --git a/src/lib/kleidiai/convolve_kleidiai.cpp b/src/lib/kleidiai/convolve_kleidiai.cpp new file mode 100644 index 0000000..9eaf490 --- /dev/null +++ b/src/lib/kleidiai/convolve_kleidiai.cpp @@ -0,0 +1,720 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include +#include "mlasi_kleidiai.h" +#include +#include + +#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" + +// Right-hand-side (weights) cache key +struct RhsCacheKey { + size_t co, ci, kh, kw, dilationh, dilationw; + size_t weights_hash; + + bool operator==(const RhsCacheKey& other) const { + return co == other.co && ci == other.ci && + kh == other.kh && kw == other.kw && + dilationh == other.dilationh && dilationw == other.dilationw && 
+ weights_hash == other.weights_hash; + } +}; + + +// Left-hand-side (input indirection) cache key +struct LhsCacheKey { + size_t ci, ih, iw; + size_t padding, sh, sw; + size_t kh, kw; + size_t dilationh, dilationw; + size_t data_hash; + + bool operator==(const LhsCacheKey& other) const { + return ci == other.ci && ih == other.ih && iw == other.iw && + padding == other.padding && sh == other.sh && sw == other.sw && + kh == other.kh && kw == other.kw && + dilationh == other.dilationh && dilationw == other.dilationw && + data_hash == other.data_hash; + } +}; + +// Derived from 2^32 * (sqrt(5) - 1) / 2 ≈ 0.6180339887 (reciprocal of the golden ratio) +// Based on Knuth's multiplicative hashing method +constexpr size_t HASH_GOLDEN_RATIO_CONST = 0x9e3779b9; + +size_t HashWeights(const float* data, size_t count = 16) { + size_t h = 0; + for (size_t i = 0; i < count; ++i) { + h ^= std::hash()(data[i]) + HASH_GOLDEN_RATIO_CONST + (h << 6) + (h >> 2); + } + return h; +} + +namespace std { + // Specialize hash type for cache keys and do it within namespace std. + // Doing this allows standard containers like std::unordered_map to find + // the appropriate hash function via template specialization, as ADL + // (argument-dependent lookup) does not apply to std::hash. + template<> + struct hash { + size_t operator()(const RhsCacheKey& k) const { + return k.weights_hash ^ + (std::hash()(k.co) << 1) ^ + (std::hash()(k.ci) << 2) ^ + (std::hash()(k.kh) << 3) ^ + (std::hash()(k.kw) << 4) ^ + (std::hash()(k.dilationh) << 5) ^ + (std::hash()(k.dilationw) << 6); + } + }; + + template<> + struct hash { + size_t operator()(const LhsCacheKey& k) const { + return k.data_hash ^ + (std::hash()(k.ci) << 1) ^ + (std::hash()(k.ih) << 2) ^ + (std::hash()(k.iw) << 3) ^ + (std::hash()(k.padding) << 4) ^ + (std::hash()(k.sh) << 5) ^ + (std::hash()(k.sw) << 6) ^ + (std::hash()(k.kh) << 7) ^ + (std::hash()(k.kw) << 8) ^ + (std::hash()(k.dilationh) << 9) ^ + (std::hash()(k.dilationw) << 10); + } + }; + +} + + +static constexpr size_t ComputeKernelSize(const size_t D, const size_t K) { + // D - dilation size + // K - kernel size + + // D*S scale 1D kernel dimension by dilation factor + // (D-1) remove affect of dilation scaling at kernel end + return (D*K) - (D - 1); +} + +static constexpr size_t ComputeConvOutSize(const size_t L, const size_t K, const size_t P, const size_t S) { + + //With start + end padding + + //L - input size + //K - kernel size + //P - Padding size + //S - stride size + + //Does the convolution compute one value or less ? 
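+    //e.g. L=5, K=3, P=1, S=2 -> ((5 - 3 + 2*1) / 2) + 1 = 3 output positions
+    //if L + 2*P < K the kernel never fits inside the padded input, so the size is 0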
+ if ( S > 0 && (L + 2*P) >= K) { + // L-(K-1) standard convolution output size is L-(K-1) for a step size of 1 with no padding + // (2*P) 1D start and end padding + // (L+2*P)-(K-1) the 1D length of convolution result for a kernel step size of 1 + // /S apply the kernel step + return (((L - K) + (2 * P)) / S) + 1; + } + return 0; +} + +static size_t ComputeMlasWorkingBufferSize(const size_t co, + const size_t ih, const size_t iw, + const size_t kh, const size_t kw, + const size_t dilationh, const size_t dilationw, + const size_t sh, const size_t sw, + const size_t padding) { + // dimensions of dilated kernel + const auto d_kh = ComputeKernelSize(dilationh, kh); + const auto d_kw = ComputeKernelSize(dilationw, kw); + + const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) * + ComputeConvOutSize(iw, d_kw, padding, sw); + + return m * co; +} + +static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) { + + //functional checks - logically can the conv be performed + if ((Parameters->Dimensions != 2) || + (Parameters->BatchCount != 1) || + (Parameters->Beta != 0.f) || + (Parameters->Padding[0] != Parameters->Padding[1]) || + (Parameters->Padding[0] != Parameters->Padding[2]) || + (Parameters->Padding[0] != Parameters->Padding[3]) || + (ComputeConvOutSize(Parameters->InputShape[0], + ComputeKernelSize(Parameters->DilationShape[0],Parameters->KernelShape[0]), + Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutSize(Parameters->InputShape[1], + ComputeKernelSize(Parameters->DilationShape[1],Parameters->KernelShape[1]), + Parameters->Padding[1], Parameters->StrideShape[1]) == 0)) { + return false; + } + + //optimization checks - is the implementation optimal for the conv request + + const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + auto M = ComputeConvOutSize(Parameters->InputShape[0], ComputeKernelSize(Parameters->DilationShape[0], + Parameters->KernelShape[0]), Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutSize(Parameters->InputShape[1], ComputeKernelSize(Parameters->DilationShape[1], + Parameters->KernelShape[1]), Parameters->Padding[1], Parameters->StrideShape[1]); + auto N = Parameters->FilterCount; + auto K = Parameters->InputChannels * Parameters->KernelShape[0] * Parameters->KernelShape[1]; + + //Can use these variables to add other conditions as required + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(m_step); + MLAS_UNREFERENCED_PARAMETER(n_step); + + if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) { + return false; + } + return true; +} + +//General purpose axis swapping +static auto Transpose4D(std::array shape_in, + const float* in, + std::array permute) { + + std::array shape_out{shape_in[permute[0]], + shape_in[permute[1]], + shape_in[permute[2]], + shape_in[permute[3]]}; + + assert((shape_in[0] * shape_in[1] * shape_in[2] * shape_in[3]) == + (shape_out[0] * shape_out[1] * shape_out[2] * shape_out[3])); + assert(permute[0] < 4 && permute[1] < 4 && permute[2] < 4 && permute[3] < 4); + + const size_t get_stride[] {shape_in[1] * shape_in[2] * shape_in[3], shape_in[2] * shape_in[3], shape_in[3]}; + auto get = [get_stride,in](const std::array& el) { + return in[el[0] * get_stride[0] + + el[1] * get_stride[1] + + el[2] * get_stride[2] + + el[3]]; + }; + + auto out_ = std::make_unique(shape_in[0] * 
shape_in[1] * shape_in[2] * shape_in[3]); + auto out = out_.get(); + + const size_t set_stride[]{shape_out[1] * shape_out[2] * shape_out[3], shape_out[2] * shape_out[3], shape_out[3]}; + auto set = [set_stride,out](const std::array& el, float v) { + out[el[0] * set_stride[0] + + el[1] * set_stride[1] + + el[2] * set_stride[2] + + el[3]] = v; + }; + + std::array shape; + for (shape[0] = 0; shape[0] < shape_in[0]; ++shape[0]) { + for (shape[1] = 0; shape[1] < shape_in[1]; ++shape[1]) { + for (shape[2] = 0; shape[2] < shape_in[2]; ++shape[2]) { + for (shape[3] = 0; shape[3] < shape_in[3]; ++shape[3]) { + set({shape[permute[0]], shape[permute[1]], shape[permute[2]], shape[permute[3]]}, get(shape)); + } + } + } + } + + return out_; +} + +//nchw to nhwc specific axis swapping +static std::unique_ptr NChwToNhwc(const size_t n, + const size_t c, + const size_t h, + const size_t w, + const float* RESTRICT in, + const size_t dilationh=1, + const size_t dilationw=1, + const bool zero_fill=false, + MLAS_THREADPOOL* ThreadPool=nullptr) { + + const auto d_h = ComputeKernelSize(dilationh, h); + const auto d_w = ComputeKernelSize(dilationw, w); + + auto t = std::make_unique(n*d_h*d_w*c); + if (zero_fill) { + std::fill(&t.get()[0], &t.get()[n*d_h*d_w*c], 0.f); + } + + if (dilationh > 1 || dilationw > 1 || n > 1) { + const size_t get_strides[] {c*h*w,h*w,w}; + auto get = [get_strides,in](const std::array& el) { + return in[el[0]*get_strides[0] + + el[1]*get_strides[1] + + el[2]*get_strides[2] + + el[3]]; + }; + + const size_t set_strides[] {d_h*d_w*c,dilationh*d_w*c,dilationw*c}; + auto set = [set_strides](const std::array& el, float v, float* out) { + out[el[0]*set_strides[0] + + el[1]*set_strides[1] + + el[2]*set_strides[2] + + el[3]] = v; + }; + + MLAS_UNREFERENCED_PARAMETER(set); + MLAS_UNREFERENCED_PARAMETER(get); + + auto out0 = t.get(); + for (size_t s0 = n; s0 > 0; --s0) { + auto out1 = out0; + for (size_t s1 = c; s1 > 0; --s1) { + auto out2 = out1; + for (size_t s2 = h; s2 > 0; --s2) { + float* RESTRICT out3 = out2; + size_t s3 = w; + for (; s3 > 4; s3 -= 4) { + auto vf32 = MlasLoadFloat32x4(in); + in += 4; + MlasStoreLaneFloat32x4<0>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<1>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<2>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<3>(out3, vf32); + out3 += set_strides[2]; + } + for (; s3 > 0; --s3) { + //set({s0,s2,s3,s1}, get({s0,s1,s2,s3}),t.get()); + *out3 = *in++; + out3 += set_strides[2]; + } + out2 += set_strides[1]; + } + out1++; + } + out0 += set_strides[0]; + } + } else { + MlasTranspose(in, t.get(), c, d_h*d_w, ThreadPool); + } + + return t; +} + +static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci, const size_t m, const size_t kh, + const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data, + const float* in_data, + const float* pad_ptr) { + + auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = MlasDivRoundup(m, m_step); + auto MaxTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), RequiredTiles); + m_step *= MlasDivRoundup(RequiredTiles, MaxTiles); + RequiredTiles = MlasDivRoundup(m, m_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(RequiredTiles), [&](ptrdiff_t tid) { + + auto m_idx = static_cast(tid) * m_step; + auto offset = 
kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx,kh*kw,ci); + + kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + m < (m_idx + m_step) ? m - m_idx : m_step, kh * kw, ci, + lhs_ptrs + m_idx * kh * kw, + reinterpret_cast(in_data), + reinterpret_cast(pad_ptr), + lhs_data + offset + ); + }); +} + +static std::shared_ptr RhsPackWeightsBiasSme(const size_t co, const size_t ci, + const size_t kh, const size_t kw, + const size_t dilationh, const size_t dilationw, + const float* weights, const float* bias, + MLAS_THREADPOOL* ThreadPool) +{ + //cache of prepacked kai rhs weights and biases + static std::unordered_map> rhs_cache; + + RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) }; + + auto found = rhs_cache.find(key); + if (found != rhs_cache.end()) { + return found->second; + } else { + // prepare mlas filter weights for kai rhs packing + // dilated nhwc format + auto nhwc = NChwToNhwc(co, ci, kh, kw, weights, dilationh, dilationw, true, ThreadPool); + + + //dilation, axis swap (n x k -> k x n) where n == co, k == d_kh x d_kw x ci + const auto d_kh = ComputeKernelSize(dilationh,kh); + const auto d_kw = ComputeKernelSize(dilationw,kw); + + //t_weights[d_kh][d_kw][ci][co] = nhwc[co][d_kh][d_kw][ci] + auto t_weights = Transpose4D({co,d_kh,d_kw,ci},&nhwc[0],{1,2,3,0}); + + const auto packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(co,d_kh*d_kw,ci); + auto packed = std::shared_ptr(new std::byte[packed_size], std::default_delete()); + + rhs_cache[key] = packed; + + std::vector bias_copy; + if (bias) { + bias_copy.assign(bias, bias + co); + } else { + bias_copy.resize(co, 0.0f); + } + + kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + co, d_kh*d_kw, ci, co * sizeof(float), &t_weights[0], bias_copy.data(), packed.get() + ); + + return packed; + } +} + +static std::shared_ptr LhsPtrFill(const size_t ci, const size_t ih, const size_t iw, + const size_t kh, const size_t kw, size_t sh, size_t sw, + const size_t padding, + const float* pad_ptr) { + size_t check_filled{0}; + + const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw); + + const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + const auto lhs_ptrs_k = kh * kw; + const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step); + auto lhs_ptrs = std::shared_ptr(new const void*[lhs_ptrs_k * lhs_ptrs_m], + std::default_delete()); + + + auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1); + auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1); + + auto ptrs_offset = [lhs_ptrs_m,lhs_ptrs_k, m_step](size_t k, size_t m) { + //(m/m_step,transpose(m_step,k) + auto offset {((m/m_step) * lhs_ptrs_k * m_step) + (k*m_step) + (m%m_step)}; + assert(offset < (lhs_ptrs_k * lhs_ptrs_m)); + + MLAS_UNREFERENCED_PARAMETER(lhs_ptrs_m); + + return offset; + }; + + auto pixel_offset = [ih, iw, ci, pad_ptr, padding](size_t h, size_t w) { + if (h < padding) { + return reinterpret_cast(&pad_ptr[0]); + } + h -= padding; + + if (w < padding) { + return reinterpret_cast(&pad_ptr[0]); + } + w -= padding; + + if ((h >= ih) || (w >= iw)) { + return reinterpret_cast(&pad_ptr[0]); + } + + auto offset{h * iw * ci + w * ci}; + assert(offset < (ih*iw*ci)); + return offset*sizeof(float); + }; + + size_t m_{0}; + auto lhs_ptrs_ = lhs_ptrs.get(); + for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) { + for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) { + size_t k_{0}; + for (size_t kh_ = 0; kh_ < kh; ++kh_) 
{ + for (size_t kw_ = 0; kw_ < kw; ++kw_) { + lhs_ptrs_[ptrs_offset(k_, m_)] = reinterpret_cast(pixel_offset(ih_+kh_, iw_+kw_)); + k_++; check_filled++; + } + } + } + } + + assert(check_filled == (lhs_ptrs_k * m)); + MLAS_UNREFERENCED_PARAMETER(check_filled); + + return lhs_ptrs; +} + +static std::unique_ptr LhsPackImageDataSme(const size_t ci, const size_t ih, const size_t iw, + const size_t kh, const size_t kw, const size_t sh, + const size_t sw, const size_t padding, const float* in, + MLAS_THREADPOOL* ThreadPool) +{ + size_t padsize = 256; + if(ci > padsize) + { + // figure out how many blocks needed to correctly fill padding + padsize = ((ci + padsize - 1) / padsize) * padsize; + } + static std::vectorpad_ptr(padsize, 0.f); + + LhsCacheKey key = { + ci, ih, iw, + padding, sh, sw, + kh, kw, + 1, 1, + HashWeights(in) + }; + + //create lhs in format required for imatmul + const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw); + + const auto lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m,kh*kw,ci); + auto lhs = std::make_unique(lhs_size); + + auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); + + //cache of computed lhs ptr offsets + static std::unordered_map> lhs_ptrs_cache; + + std::shared_ptr lhs_ptrs; + if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { + lhs_ptrs = found->second; + } else { + lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]); + lhs_ptrs_cache[key] = lhs_ptrs; + } + + MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], &nhwc[0], &pad_ptr[0]); + + return lhs; +} + +static void ConvolveSme(const size_t co, //channels out + const size_t ci, //channels in + const size_t ih, //image height + const size_t iw, //image width + const size_t kh, //kernel height + const size_t kw, //kernel width + const size_t sh, //kernel stride height + const size_t sw, //kernel stride width + const size_t dilationh, //kernel dilation stride + const size_t dilationw, //kernel dilation stride + const size_t padding, //padding size + const size_t groups, //number of filter groups + const float* weights, //kernel weights [co,ci,ih,iw] + const float* bias, //kernel biases + const float* in, //in image data + float* out, //out image data + float* tmp_mlas_aligned, //intermediate buffer if we need to perform a transpose + MLAS_THREADPOOL* ThreadPool) { + + //RhsPackWeightsBiasSme() - to perform dilation increases kernel size and masks unused weights + //compute corrected dimensions of dilated kernel + const auto d_kh = ComputeKernelSize(dilationh, kh); + const auto d_kw = ComputeKernelSize(dilationw, kw); + + //run igemm based convolution + const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) * + ComputeConvOutSize(iw, d_kw, padding, sw); + + auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + //tile iteration dimensions + std::array dim; + dim[0] = 1; // B + dim[1] = MlasDivRoundup(m, m_step); // M + dim[2] = MlasDivRoundup(co, n_step); // N + + //Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0]*dim[1]*dim[2]); + + //scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + 
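+    //after this split dim[1]*dim[2] is roughly RequiredTiles, i.e. about one macro tile per available
+    //thread (an approximation - the round-ups can overshoot slightly)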
//compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(m, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(co, dim[2]), n_step); + + //update tile iterations + dim[1] = MlasDivRoundup(m, m_step); + dim[2] = MlasDivRoundup(co, n_step); + + for (size_t g = 0; g < groups; ++g) { + + auto result{out}; + //do we require a post matmul transpose ? + //output is m x n or image_data x co or hw x co + //MLAS require it as n x m (or co x hw), transpose required + if (co > 1) { + //intermediate buffer required, pre-transpose + //Note: because we are calling MlasTranspose() need to ensure we use a MLAS aligned buffer + result = tmp_mlas_aligned; + } + + auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool); + auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool); + + + MlasTrySimpleParallel(ThreadPool, + static_cast(dim[0]*dim[1]*dim[2]), + [&](ptrdiff_t tid) + { + //compute B,M,N index from iteration index + //ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = + kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx*n_step, + d_kh*d_kw,ci); + + auto BTile = reinterpret_cast( + reinterpret_cast(rhs.get()) + rhs_packed_offset + ); + + // Get lhs tile, A + const size_t lhs_packed_offset = + kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx*m_step, + d_kh*d_kw,ci); + + auto ATile = reinterpret_cast( + reinterpret_cast(lhs.get()) + lhs_packed_offset + ); + + auto TileSizeM = (MIdx + 1) * m_step > m ? (m - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > co ? 
(co - NIdx * n_step) : n_step; + + // Get result tile, C + auto CTile = &reinterpret_cast(result)[ + MIdx * m_step * co * sizeof(float) + + NIdx * n_step * sizeof(float)]; + + kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + }); + + if (result == tmp_mlas_aligned) { + //Note: this could be absorbed into post conv activation + MlasTranspose(tmp_mlas_aligned, out, m, co, ThreadPool); + } + + in += ci * ih * iw; + out += m * co; + weights += co * ci * kh * kw; + if(bias){ + bias += co; + } + } +} + +bool MLASCALL +ArmKleidiAI::MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool) +{ + //Check dimensions before accessing + if (Dimensions < 2) { + return false; + } + + Parameters->Activation = Activation; + Parameters->Dimensions = Dimensions; + Parameters->BatchCount = BatchCount; + Parameters->GroupCount = GroupCount; + Parameters->InputChannels = InputChannels; + Parameters->FilterCount = FilterCount; + Parameters->Beta = Beta; + + size_t InputSize = 1; + size_t OutputSize = 1; + size_t K = InputChannels; + + for (size_t dim = 0; dim < Dimensions; dim++) { + + Parameters->InputShape[dim] = size_t(InputShape[dim]); + Parameters->OutputShape[dim] = size_t(OutputShape[dim]); + Parameters->KernelShape[dim] = size_t(KernelShape[dim]); + Parameters->DilationShape[dim] = size_t(DilationShape[dim]); + Parameters->Padding[dim] = size_t(Padding[dim]); + Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]); + Parameters->StrideShape[dim] = size_t(StrideShape[dim]); + + InputSize *= Parameters->InputShape[dim]; + OutputSize *= Parameters->OutputShape[dim]; + K *= Parameters->KernelShape[dim]; + } + + Parameters->InputSize = InputSize; + Parameters->OutputSize = OutputSize; + Parameters->K = K; + + Parameters->ThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if(!CheckCapabilitiesSme(Parameters)){ + return false; + } + + //Allocate an aligned buffer for MlasTranspose() + *WorkingBufferSize = ComputeMlasWorkingBufferSize(Parameters->FilterCount, + Parameters->InputShape[0], Parameters->InputShape[1], + Parameters->KernelShape[0], Parameters->KernelShape[1], + Parameters->DilationShape[0], Parameters->DilationShape[1], + Parameters->StrideShape[0], Parameters->StrideShape[1], + Parameters->Padding[0]); + return true; +} + +bool +MLASCALL +ArmKleidiAI::MlasConv( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ) +{ + if(!CheckCapabilitiesSme(Parameters)){ + //Fallback to Default Mlas + return false; + }; + ConvolveSme(Parameters->FilterCount, Parameters->InputChannels, // channel out, in + Parameters->InputShape[0], Parameters->InputShape[1], // image dimensions + Parameters->KernelShape[0], Parameters->KernelShape[1], // kernel dimensions + Parameters->StrideShape[0], Parameters->StrideShape[1], // kernel stride dimensions + Parameters->DilationShape[0], Parameters->DilationShape[1], // kernel dilation + 
Parameters->Padding[0], // image padding + Parameters->GroupCount, // filter groups + Filter, Bias, Input, Output, WorkingBuffer, ThreadPool); + + MlasActivation(Parameters->Activation, Output, nullptr, Parameters->FilterCount, Parameters->OutputSize, + Parameters->OutputSize); + return true; +} diff --git a/src/lib/kleidiai/mlasi_kleidiai.h b/src/lib/kleidiai/mlasi_kleidiai.h new file mode 100644 index 0000000..2e9c457 --- /dev/null +++ b/src/lib/kleidiai/mlasi_kleidiai.h @@ -0,0 +1,151 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "mlasi.h" + +// Fix to ensure compatibility with MSVC build +#if defined(_MSC_VER) + #define RESTRICT __restrict +#else + #define RESTRICT __restrict__ +#endif +namespace ArmKleidiAI { +// By default we should try for SME2 first before falling back to SME. +inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2(); + +// +// Buffer packing routines. +// + +size_t +MLASCALL +MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K + ); + +bool +MLASCALL +MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ); + +bool +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +); + +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +); + +//pack symmetric quantized B and dynamic quantized A +void +MLASCALL +MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool + ); + +bool +MLASCALL +MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool); + +bool +MLASCALL +MlasConv( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +} + +/*++ + +Routine Description: + + This routine determines if a wraparound will occur when multiplying two size_t variables + Uses __builtin_mul_overflow if available on the current system and if not falls back + to a default implementation to check this wraparound. + +Arguments: + + a - Supplies the first number to be muliplied. + + b - Supplies the second number to be muliplied. + + out - pointer to a size_t which acts as the return value in success cases. 
+
+Return Value:
+
+    Returns false if the operation was successful
+    Returns true if wraparound of size_t was detected
+
+--*/
+inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) {
+#if defined(__has_builtin)
+# if __has_builtin(__builtin_mul_overflow)
+    return __builtin_mul_overflow(a, b, out);
+# endif
+#endif
+    // Fallback to manual check if builtin not available
+    if (b != 0 && a > SIZE_MAX / b) return true;
+    if (out) *out = a * b;
+    return false;
+}
diff --git a/src/lib/kleidiai/qgemm_kleidiai.cpp b/src/lib/kleidiai/qgemm_kleidiai.cpp
new file mode 100644
index 0000000..fb38f2c
--- /dev/null
+++ b/src/lib/kleidiai/qgemm_kleidiai.cpp
@@ -0,0 +1,116 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: MIT
+//
+
+#include
+
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
+#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
+
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
+
+#include "mlasi_kleidiai.h"
+
+//Matmul with float output of dynamic quantized A and symmetric quantized B.
+
+size_t
+MLASCALL
+ArmKleidiAI::MlasDynamicQgemmPackBSize(
+    size_t N,
+    size_t K
+) {
+    //Default to sme2_mopa but this may not always be the most optimal kernel variant to use
+    auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+    auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+    auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+
+    //regardless of kernel variant use neon packing variant
+    return kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
+}
+
+void
+MLASCALL
+ArmKleidiAI::MlasDynamicQgemmPackB(
+    size_t N,
+    size_t K,
+    const int8_t* B,
+    const float* Scales,
+    const float* Bias,
+    void* PackedB
+) {
+    // Default to sme2_mopa but this may not always be the most optimal kernel variant to use
+    auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+    auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+    auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+
+    // y - float output
+    // scale_factor_lhs - lhs scaling factor
+    // scale_factor_rhs - rhs scaling factor
+    // lhs_q - lhs quantized (asymmetric, so has zero point)
+    // rhs_q - rhs quantized (symmetric so no zero point)
+    // lhs_zp - lhs zero point
+    // y = (1/(scale_factor_lhs * scale_factor_rhs) * sum( (lhs_q + lhs_zp)*rhs_q )) + bias
+
+    // rhs packing requires lhs_zp because it will perform lhs_zp*rhs_q during rhs packing.
+    // Because lhs quantization is hidden from us by lhs quant packing, we don't have a value for
+    // lhs_zp; it is determined during lhs dynamic quantization.
+
+    kai_rhs_pack_qsi8cx_params params{
+        1,   // lhs_zp - set to 1 so it becomes sum((lhs_q + 1)*rhs_q )),
+             // the actual lhs_zp is applied during the matmul
+        1.f  // it is not used
+    };
+
+    //regardless of kernel variant use neon packing variant
+    kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(1, N, K, nr, kr, sr, B,
+                                             // N bias values
+                                             Bias,
+                                             // N scale values
+                                             Scales, PackedB, 0, &params);
+}
+
+void
+MLASCALL
+ArmKleidiAI::MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +) { + for (auto b = BatchN; b > 0; --b,++DataParams) { + auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + + //TODO enable multi-threading for lhs packing and matmul + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + //Dynamic Quantize A - lhs + auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); + std::byte* lhs = nullptr; + std::unique_ptr fallback; + + if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) { + lhs = static_cast(DataParams->Workspace); + } else { + fallback = std::make_unique(lhs_size); + lhs = fallback.get(); + } + + kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A, + Shape.K*sizeof(float), lhs); + + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( + Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB, + DataParams->C, + Shape.N * sizeof(float), + sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } +} diff --git a/src/lib/kleidiai/sgemm_kleidiai.cpp b/src/lib/kleidiai/sgemm_kleidiai.cpp new file mode 100644 index 0000000..435ff1f --- /dev/null +++ b/src/lib/kleidiai/sgemm_kleidiai.cpp @@ -0,0 +1,432 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h" +#include "mlasi_kleidiai.h" + + +// Thread-local reusable buffers to reduce allocation overhead across tiles. +struct KaiTlsBuffers { + std::vector output_tile; + std::vector bias_zero; + std::vector rhs_packed; + std::vector lhs_packed; +}; +static thread_local KaiTlsBuffers g_kai_tls; + +size_t +MLASCALL +ArmKleidiAI::MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K +) +/*++ + +Routine Description: + + This routine computes the length in bytes for the packed matrix B buffer. + +Arguments: + + TransA - Supplies the transpose operation on A matrix + + TransB - Supplies the transpose operation on B matrix + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + +Return Value: + + Returns the size in bytes for the packed matrix B buffer. + +--*/ +{ + if (TransA != CblasNoTrans || N == 0 || K == 0) { + return 0; + } + // + // Compute the number of bytes required to hold the packed buffer. 
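+    // (The size comes straight from the KleidiAI rhs-pack helpers below, so it includes the
+    // SME vector-length dependent blocking and the space the "biasf32" packed layout reserves for
+    // per-column f32 bias values - generally more than N * K * sizeof(float).)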
+ // + size_t bytes = 0; + if (TransA == CblasNoTrans) { + switch (TransB) { + case CblasNoTrans: + bytes = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K); + break; + case CblasTrans: + bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K); + break; + default: + return 0; + } + } else { + return 0; + } + + return bytes; +} + +bool +MLASCALL +ArmKleidiAI::MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB +) +/*++ + +Routine Description: + + This routine packs the contents of matrix B to the destination buffer. The + destination buffer should be sized based on MlasGemmPackBSize(). For best + performance, the destination buffer should be aligned to the value returned + from MlasGetPreferredBufferAlignment(). + +Arguments: + + TransA - Supplies the transpose operation for matrix A. + + TransB - Supplies the transpose operation for matrix B. + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + + B - Supplies the address of matrix B. + + ldb - Supplies the first dimension of matrix B. + + PackedB - Supplies the address of packed matrix B. + +Return Value: + + Returns true if the packing operation was handled by KleidiAI. + Returns false if the configuration requires a fallback to the default MLAS implementation. + +--*/ +{ + if (N == 0 || K == 0) { + return false; + } + + if (TransA == CblasNoTrans) { + const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + + // Ensure size and zero the used span. + g_kai_tls.bias_zero.resize(N, 0.0f); + + switch (TransB) { + case CblasNoTrans: + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr); + break; + case CblasTrans: + kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr); + break; + default: + return false; + } + return true; + } + else{ + return false; + } +} + +bool +MLASCALL +ArmKleidiAI::MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool +) +/*++ + +Routine Description: + + This routine performs a batched matrix multiplication (GEMM) operation using KleidiAI kernels. + It handles both packed and unpacked inputs and manages tiling and kernel selection depending on + SME2 availability. If packing is needed, it prepares the required buffers and invokes the + appropriate left-hand side (LHS) and right-hand side (RHS) pack functions. + + The function also applies alpha and beta scaling to the result, supports efficient memcpy + paths where possible, and dispatches tile-level GEMM work using multithreading. + +Arguments: + + TransA - Supplies the transpose operation for matrix A. + + TransB - Supplies the transpose operation for matrix B. 
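A hypothetical caller-side flow for the two entry points above, assuming mlasi_kleidiai.h is visible: query the packed size first, then pack, and fall back to the stock MLAS path whenever KleidiAI declines the configuration.

#include <cstddef>
#include <memory>

bool TryPackBWithKleidiAI(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
                          size_t N, size_t K, const float* B, size_t ldb,
                          std::unique_ptr<std::byte[]>& PackedB) {
    const size_t Bytes = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
    if (Bytes == 0) {
        return false;  // unsupported layout, caller uses the default MLAS packing
    }
    PackedB = std::make_unique<std::byte[]>(Bytes);
    return ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, B, ldb, PackedB.get());
}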
+ + M - Supplies the number of rows of matrix A and matrix C. + + N - Supplies the number of columns of matrix B and matrix C. + + K - Supplies the number of columns of matrix A and rows of matrix B. + + Data - Supplies a pointer to the MLAS_SGEMM_DATA_PARAMS array containing per-batch input/output pointers and parameters. + + BatchSize - Supplies the number of independent GEMM computations to perform in the batch. + + ThreadPool - Supplies the thread pool to parallelize computation across batches and tiles. + +Return Value: + + Returns true if the GEMM operation was handled by KleidiAI. + Returns false if the configuration requires a fallback to the default MLAS implementation. + +--*/ +{ + if (M == 0 || N == 0) { + return true; + } + + if (Data->alpha == 0.0f || K == 0) { + if (Data->beta == 0.0f) { + for (size_t i = 0; i < M; ++i) { + std::fill_n(Data->C + i * Data->ldc, N, 0.0f); + } + } else if (Data->beta != 1.0f) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + Data->C[i * Data->ldc + j] *= Data->beta; + } + } + } + return true; + } + + const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + + size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + size_t n_step = UseSME2 ? 
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + + if ((M < m_step || N < n_step) && !Data->BIsPacked) { + // Fallback to MLAS + return false; + } + + size_t LhsPackedStride = 0; + std::byte* LhsPackedData = nullptr; + + LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr); + + size_t lhs_resize = 0; + if(mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize)) + { + // size_t wraparound detected for LhsPackedStride, fallback to MLAS + return false; + } + + g_kai_tls.lhs_packed.resize(lhs_resize); + LhsPackedData = g_kai_tls.lhs_packed.data(); + + // RHS packed buffer: use TLS reusable vector to minimize allocations + size_t RhsPackedStride = 0; + std::byte* RhsPackedData = nullptr; + + // It is assumed all B batches require packing or not + if (Data[0].BIsPacked) { + // We have already decided the matmul variant we are using, before having values for M,N,K + MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + }); + } else { + // Multithread pack lhs and rhs + RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K); + size_t rhs_resize = 0; + if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize)) + { + // size_t wraparound detected for RhsPackedStride, fallback to MLAS + return false; + } + + g_kai_tls.rhs_packed.resize(rhs_resize); + RhsPackedData = g_kai_tls.rhs_packed.data(); + + MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) { + if (batch_idx & 0x1) { + batch_idx >>= 1; + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + } else { + batch_idx >>= 1; + std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]); + ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, + reinterpret_cast(Data[batch_idx].B), + Data[batch_idx].ldb, RhsPackedPtr); + } + }); + } + + // tile iteration dimensions + std::array dim; + dim[0] = BatchSize; // B + dim[1] = MlasDivRoundup(M, m_step); // M + dim[2] = MlasDivRoundup(N, n_step); // N + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); + + // scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + // compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step); + + // update tile iterations + dim[1] = MlasDivRoundup(M, m_step); + dim[2] = MlasDivRoundup(N, n_step); + + // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop. + // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively. 
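For reference, the even/odd job split used by the packing loop above can be written out on its own: a single parallel range of 2 * BatchSize indices feeds both packers, odd ids packing the LHS of batch id/2 and even ids packing the RHS of the same batch. The decode below is a standalone sketch of just that index arithmetic.

#include <cstddef>

enum class PackSide { Lhs, Rhs };

struct PackTask {
    PackSide side;
    std::ptrdiff_t batch;
};

// Map one index of the 2 * BatchSize parallel range onto (packer, batch).
PackTask DecodePackTask(std::ptrdiff_t JobIndex) {
    PackTask task;
    task.side = (JobIndex & 0x1) ? PackSide::Lhs : PackSide::Rhs;
    task.batch = JobIndex >> 1;
    return task;
}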
+ size_t max_tile_elems = 0; + if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) { + // size_t wraparound detected for tile size, fallback to MLAS + return false; + } + + MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { + // compute B,M,N index from iteration index + ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = + UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K) + : kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K); + + const std::byte* B_base = Data[0].BIsPacked + ? reinterpret_cast(Data[BIdx].B) + : (RhsPackedData + RhsPackedStride * BIdx); + auto BTile = reinterpret_cast(B_base + rhs_packed_offset); + + // Get lhs tile, A + const size_t lhs_packed_offset = + UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K) + : kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K); + + const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx; + auto ATile = reinterpret_cast(A_base + lhs_packed_offset); + + auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step; + + // Get result tile, C + auto CTile = reinterpret_cast( + reinterpret_cast(Data[BIdx].C) + + MIdx * m_step * Data[BIdx].ldc * sizeof(float) + + NIdx * n_step * sizeof(float) + ); + // Allocate temporary buffer for raw A*B result (TLS reusable buffer) + size_t tile_elems = TileSizeM * TileSizeN; + + // resize the tile to the required size + g_kai_tls.output_tile.resize(tile_elems); + + float* temp_tile = g_kai_tls.output_tile.data(); + std::fill_n(temp_tile, tile_elems, 0.0f); + + if (UseSME2) { + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( + TileSizeM, + TileSizeN, + K, + ATile, BTile, temp_tile, + TileSizeN * sizeof(float), sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } else { + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( + TileSizeM, + TileSizeN, + K, + ATile, BTile, temp_tile, + TileSizeN * sizeof(float), sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } + + // Final output tile pointer + float* dst_tile = reinterpret_cast(CTile); + + // quick copy of data in cases where we are not scaling or accumulating anything + // with bounds checking on tile sizing to ensure the data fits in the memory block + bool can_memcpy = ( + Data[BIdx].alpha == 1.0f && + Data[BIdx].beta == 0.0f && + Data[BIdx].ldc == TileSizeN && + MIdx * m_step + TileSizeM <= M && + NIdx * n_step + TileSizeN <= N && + TileSizeM != 0 && + TileSizeN != 0); + + if (can_memcpy) { + std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float)); + return; + } + + float alpha = Data[BIdx].alpha; + float beta = Data[BIdx].beta; + size_t ldc = Data[BIdx].ldc; + + for (size_t i = 0; i < TileSizeM; ++i) { + for (size_t j = 0; j < TileSizeN; ++j) { + const size_t temp_idx = i * TileSizeN + j; + const size_t dst_idx = i * ldc + j; + + float ab = temp_tile[temp_idx]; + float c_orig = dst_tile[dst_idx]; + + if (alpha == 1.0f && beta == 0.0f) { + dst_tile[dst_idx] = ab; + } else if (alpha == 1.0f) { + dst_tile[dst_idx] = ab + beta * c_orig; + } else if 
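A standalone restatement of the tile-index arithmetic inside the parallel loop above: the flat task id is decoded into (batch, M-tile, N-tile) coordinates and the trailing tiles are clamped so they never read past M or N.

#include <array>
#include <cstddef>

struct TileCoords {
    size_t Batch, MTile, NTile, TileSizeM, TileSizeN;
};

TileCoords DecodeTile(size_t Tid, const std::array<size_t, 3>& Dim,
                      size_t M, size_t N, size_t MStep, size_t NStep) {
    TileCoords t{};
    t.Batch = Tid / (Dim[1] * Dim[2]);
    t.MTile = (Tid % (Dim[1] * Dim[2])) / Dim[2];
    t.NTile = (Tid % (Dim[1] * Dim[2])) % Dim[2];
    // Clamp the last tile in each dimension to the remaining rows/columns.
    t.TileSizeM = ((t.MTile + 1) * MStep > M) ? (M - t.MTile * MStep) : MStep;
    t.TileSizeN = ((t.NTile + 1) * NStep > N) ? (N - t.NTile * NStep) : NStep;
    return t;
}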
(beta == 0.0f) { + dst_tile[dst_idx] = alpha * ab; + } else { + dst_tile[dst_idx] = alpha * ab + beta * c_orig; + } + } + } + return; + }); + return true; +} diff --git a/src/lib/logistic.cpp b/src/lib/logistic.cpp index ecca39f..f8244b7 100644 --- a/src/lib/logistic.cpp +++ b/src/lib/logistic.cpp @@ -110,7 +110,10 @@ Return Value: q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.beta_2)); q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasLogisticConstants.beta_0)); - MlasStoreFloat32x4(Output, MlasAddFloat32x4(MlasDivideFloat32x4(p, q), MlasBroadcastFloat32x4(0.5f))); + MlasStoreFloat32x4(Output, MlasClampFloat32x4( + MlasAddFloat32x4(MlasDivideFloat32x4(p, q), MlasBroadcastFloat32x4(0.5f)), + 0.0f, + 1.0f)); Input += 4; Output += 4; @@ -145,7 +148,7 @@ Return Value: q = q * ValueSquared + MlasLogisticConstants.beta_2; q = q * ValueSquared + MlasLogisticConstants.beta_0; - *Output++ = (p / q) + 0.5f; + *Output++ = std::clamp((p / q) + 0.5f, 0.0f, 1.0f); N -= 1; } @@ -178,7 +181,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) GetMlasPlatform().LogisticKernelRoutine(Input, Output, N); #else MlasLogisticKernel(Input, Output, N); diff --git a/src/lib/mlasi.h b/src/lib/mlasi.h index 13ea8d9..d25f1f2 100644 --- a/src/lib/mlasi.h +++ b/src/lib/mlasi.h @@ -18,6 +18,7 @@ Module Name: #pragma once #include +#include #include #include #include @@ -69,6 +70,9 @@ Module Name: #undef pixel #undef bool #endif +#if defined(__s390x__) +#include +#endif #if defined(__loongarch64) #include #endif @@ -117,9 +121,8 @@ Module Name: #ifdef MLAS_NO_EXCEPTION -MLAS_FORCEINLINE -void -MlasPrintFinalMessage(const std::string& msg) +MLAS_FORCEINLINE void + MlasPrintFinalMessage(const std::string& msg) { #if defined(__ANDROID__) __android_log_print(ANDROID_LOG_ERROR, "mlas", "%s", msg.c_str()); @@ -133,6 +136,7 @@ MlasPrintFinalMessage(const std::string& msg) #endif } + #define MLAS_THROW_EX(ex, what) \ do { \ std::string msg = #ex; \ @@ -161,7 +165,7 @@ MlasPrintFinalMessage(const std::string& msg) #include "core/common/cpuid_info.h" using MLAS_CPUIDINFO = onnxruntime::CPUIDInfo; -#include "core/framework/float16.h" +#include "core/common/float16.h" #else // BUILD_MLAS_NO_ONNXRUNTIME @@ -191,6 +195,8 @@ class MLASCPUIDInfo bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + bool HasArmSVE() const { return has_arm_sve_; } + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } @@ -201,6 +207,7 @@ class MLASCPUIDInfo bool has_arm_neon_dot_{false}; bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_{false}; bool has_arm_sve_i8mm_{false}; bool has_arm_neon_bf16_{false}; }; @@ -260,6 +267,22 @@ struct MLFloat16 { MLFloat16() = default; explicit constexpr MLFloat16(uint16_t x) : val(x) {} explicit MLFloat16(float ff) : val(MLAS_Float2Half(ff)) {} + constexpr static MLFloat16 FromBits(uint16_t x) noexcept { return MLFloat16(x); } + + MLFloat16 Abs() const noexcept { + return MLFloat16(static_cast(val & ~kSignMask)); + } + bool IsNaN() const noexcept { + return Abs().val > kPositiveInfinityBits; + } + bool IsNegative() const noexcept { + return static_cast(val) < 0; + } + MLFloat16 Negate() const { + return MLFloat16(IsNaN() ? 
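For clarity, the per-element epilogue that the tile loop in sgemm_kleidiai.cpp above finishes with is the usual C = alpha * (A*B) + beta * C update, with fast cases for alpha == 1 and beta == 0; a scalar restatement:

#include <cstddef>

// ab holds the raw TileSizeM x TileSizeN product, c is the destination tile with
// row stride ldc; the branch structure mirrors the epilogue above.
void ApplyGemmEpilogue(const float* ab, float* c, size_t rows, size_t cols,
                       size_t ldc, float alpha, float beta) {
    for (size_t i = 0; i < rows; ++i) {
        for (size_t j = 0; j < cols; ++j) {
            const float product = ab[i * cols + j];
            float& dst = c[i * ldc + j];
            if (alpha == 1.0f && beta == 0.0f) {
                dst = product;
            } else if (alpha == 1.0f) {
                dst = product + beta * dst;
            } else if (beta == 0.0f) {
                dst = alpha * product;
            } else {
                dst = alpha * product + beta * dst;
            }
        }
    }
}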
val : static_cast(val ^ kSignMask)); + } + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kPositiveInfinityBits = 0x7C00U; float ToFloat() const { return MLAS_Half2Float(val); } @@ -301,6 +324,8 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // Define the default strides to step through slices of the input matrices. // +#define MLAS_HGEMM_STRIDEN 128 +#define MLAS_HGEMM_STRIDEK 128 #define MLAS_SGEMM_STRIDEN 128 #define MLAS_SGEMM_STRIDEK 128 #define MLAS_SGEMM_PACKED_STRIDEN 128 @@ -317,6 +342,7 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // the effort at this time. // +#define MLAS_HGEMM_STRIDEN_THREAD_ALIGN 32 #define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16 #define MLAS_DGEMM_STRIDEN_THREAD_ALIGN 8 #define MLAS_QGEMM_STRIDEN_THREAD_ALIGN 16 @@ -326,7 +352,7 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || \ - defined(MLAS_TARGET_LARCH64) + defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_S390X) typedef size_t @@ -358,6 +384,22 @@ size_t bool ZeroMode ); +#ifdef FORCE_GENERIC_ALGORITHMS +typedef +size_t +(MLASCALL MLAS_GEMM_FLOAT_KERNEL_GENERIC)( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha + ); +#endif + #else #if defined(__aarch64__) && defined(__linux__) @@ -711,6 +753,24 @@ void float Scale, int8_t ZeroPoint); +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_U8_KERNEL)( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_S8_KERNEL)( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -727,12 +787,129 @@ struct MLAS_QUANT_KERNEL size_t KernelSize ); }; +typedef +void +(MLASCALL MLAS_CONV_FLOAT_FN)( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +typedef +bool +(MLASCALL MLAS_CONV_FLOAT_OVERRIDE)( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +// TODO: Investigate if overridden typedefs can be removed +typedef +void +(MLASCALL MLAS_CONV_PREPARE_FLOAT_FN)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); +typedef +bool +(MLASCALL MLAS_CONV_PREPARE_FLOAT_OVERRIDE)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); + +typedef void (MLASCALL MLAS_GEMM_BATCH)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t 
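The MLFloat16 helpers completed above are pure bit manipulation; a standalone sketch on plain uint16_t shows the same rules: the sign sits in bit 15, any value whose magnitude bits exceed the +infinity pattern is a NaN, and Negate leaves NaNs untouched.

#include <cstdint>

constexpr uint16_t kFp16SignMask = 0x8000u;
constexpr uint16_t kFp16PositiveInfinityBits = 0x7C00u;

constexpr uint16_t Fp16Abs(uint16_t bits) { return static_cast<uint16_t>(bits & ~kFp16SignMask); }
constexpr bool Fp16IsNaN(uint16_t bits) { return Fp16Abs(bits) > kFp16PositiveInfinityBits; }
constexpr uint16_t Fp16Negate(uint16_t bits) {
    return Fp16IsNaN(bits) ? bits : static_cast<uint16_t>(bits ^ kFp16SignMask);
}

static_assert(Fp16Abs(0x8001u) == 0x0001u, "sign bit cleared");
static_assert(Fp16IsNaN(0x7C01u), "infinity exponent with non-zero mantissa is NaN");
static_assert(Fp16Negate(0x3C00u) == 0xBC00u, "1.0 becomes -1.0");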
BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef bool (MLASCALL MLAS_GEMM_BATCH_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); + +typedef bool (MLASCALL MLAS_GEMM_PACK_B_KERNEL_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); extern "C" { #if defined(MLAS_TARGET_AMD64_IX86) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelSse; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx; +#ifdef FORCE_GENERIC_ALGORITHMS + MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelZero; + MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelAdd; +#endif #if defined(MLAS_TARGET_AMD64) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelFma3; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx512F; @@ -748,6 +925,12 @@ extern "C" { MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelVSX; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelVSX; +#elif defined(MLAS_TARGET_S390X) + MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernel; + MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZVECTOR; + MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernel; + MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelZVECTOR; + MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelZVECTOR; #elif defined(MLAS_TARGET_LARCH64) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLSX; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLasx; @@ -778,6 +961,15 @@ extern "C" { #if defined(__aarch64__) && defined(__linux__) MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon; #endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; @@ -863,6 +1055,8 @@ extern "C" { MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F; @@ -924,6 +1118,7 @@ extern "C" { #define MLAS_SGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#define MLAS_HGEMM_THREAD_COMPLEXITY 65536 #if defined(__aarch64__) && defined(__linux__) #define 
MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) @@ -972,8 +1167,14 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchZVECTOR; + +#if defined(MLAS_TARGET_WASM_RELAXED_SIMD) +extern bool HasUSDot(); +#endif // // Symmetric quantized qgemm dispatch structure @@ -1017,17 +1218,44 @@ extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; // Float/quantized n-bit integer matrix/matrix multiply dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH; +struct MLAS_QNBIT_GEMM_DISPATCH; + +const MLAS_QNBIT_GEMM_DISPATCH& +GetMlasQNBitGemmDispatchNeon( + bool InitializeWithDotSupport, + bool InitializeWithI8MMSupport +); + +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; + +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; + +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; +// +// Rotary embedding dispatch structure. +// +struct MLAS_ROPE_DISPATCH; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2; + +// +// half gemm dispatch structure +// +struct MLAS_HGEMM_DISPATCH; +extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; +// softmax dispatch structure +struct MLAS_SOFTMAX_DISPATCH; +extern const MLAS_SOFTMAX_DISPATCH MlasSoftmaxDispatchNeon; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; +// eltwise dispatch structure +struct MLAS_ELTWISE_DISPATCH; +extern const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon; // // Quantized depthwise convolution kernels. 
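The dispatch structs declared above are opaque at this point in the header, so the sketch below uses a made-up MyDispatch type rather than the real MLAS_ROPE_DISPATCH or MLAS_HGEMM_DISPATCH layout; it only illustrates the pattern of per-backend const tables of kernel pointers selected once during platform initialisation.

#include <cstddef>

struct MyDispatch {
    void (*Kernel)(const float* input, float* output, size_t n);
};

static void GenericKernel(const float* input, float* output, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        output[i] = input[i];
    }
}

static const MyDispatch MyGenericDispatch{GenericKernel};

struct MyPlatform {
    const MyDispatch* Dispatch{nullptr};
};

static void InitMyPlatform(MyPlatform& platform) {
    // A real platform.cpp would choose between several backend instances here,
    // based on the CPU features detected at startup.
    platform.Dispatch = &MyGenericDispatch;
}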
@@ -1091,7 +1319,19 @@ struct MLAS_PLATFORM { MLAS_PLATFORM(void); -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) + // TODO: move to cpuinfo + bool Avx2Supported_ = false; + bool Avx512Supported_ = false; + bool ArmNeonIsQuantActivationsUnsigned = false; + + // Mlas overrides initialisation + MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr; + MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr; + MLAS_GEMM_PACK_B_KERNEL_OVERRIDE* MlasGemmPackBOverride = nullptr; + MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; + MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; + +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; #endif #if defined(MLAS_TARGET_LARCH64) @@ -1119,6 +1359,14 @@ struct MLAS_PLATFORM { const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch; +#if defined(MLAS_USE_ARM_NEON_NCHWC) + MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; + MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; + MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; + uint32_t NchwcBlockSize; +#endif #endif const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch{nullptr}; @@ -1130,7 +1378,7 @@ struct MLAS_PLATFORM { MLAS_QUANT_KERNEL::DepthwiseKernel* ConvDepthwiseS8S8Kernel; MLAS_QUANT_KERNEL::DepthwiseKernel* ConvDepthwiseS8U8Kernel; -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel; const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch; MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; @@ -1140,6 +1388,15 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; #endif + +#if defined(MLAS_USE_SVE) || defined(MLAS_TARGET_AMD64) + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ErfKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* LogisticKernelRoutine; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; + MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL* ComputeSumExpF32Kernel; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; +#endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; @@ -1155,16 +1412,10 @@ struct MLAS_PLATFORM { MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ErfKernelRoutine; MLAS_QLINEAR_BINARY_OP_S8_KERNEL* QLinearAddS8Kernel; MLAS_QLINEAR_BINARY_OP_U8_KERNEL* QLinearAddU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ComputeExpF32Kernel; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* LogisticKernelRoutine; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* TanhKernelRoutine; - MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL* ComputeSumExpF32Kernel; - MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; - MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; - MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL* ReduceMinimumMaximumF32Kernel; MLAS_QUANTIZE_LINEAR_S8_KERNEL* 
QuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; @@ -1172,11 +1423,14 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL* DequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL* DequantizeLinearU8Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; #elif defined(MLAS_TARGET_ARM64) static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT * 4; + static constexpr size_t MLAS_NEON_NCHWC_BLOCK_SIZE = 16; #else static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; #endif @@ -1184,10 +1438,15 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; - const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; + const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; + + const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; + const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr}; + const MLAS_SOFTMAX_DISPATCH* SoftmaxDispatch{nullptr}; + const MLAS_ELTWISE_DISPATCH* EltwiseDispatch{nullptr}; }; inline @@ -1381,6 +1640,8 @@ MlasConvDepthwiseFloat_CHW( #define MLAS_NEON64_INTRINSICS #elif defined(MLAS_TARGET_POWER) #define MLAS_VSX_INTRINSICS +#elif defined(MLAS_TARGET_S390X) +#define MLAS_ZVECTOR_INTRINSICS #elif defined(MLAS_TARGET_AMD64_IX86) #define MLAS_SSE2_INTRINSICS #if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__)) @@ -1397,6 +1658,9 @@ MlasConvDepthwiseFloat_CHW( #endif #elif defined(MLAS_TARGET_WASM_SIMD) #define MLAS_WASM_SIMD_INTRINSICS +#if defined(MLAS_TARGET_WASM_RELAXED_SIMD) +#define MLAS_WASM_RELAXED_SIMD_INTRINSICS +#endif #elif defined(MLAS_TARGET_LARCH64) #define MLAS_LSX_INTRINSICS #endif @@ -1447,6 +1711,8 @@ MlasCastToInt32x4(MLAS_FLOAT32X4 Vector) return _mm_cvttps_epi32(Vector); #elif defined(MLAS_VSX_INTRINSICS) return vec_cts(Vector, 0); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_signed(Vector); #elif defined(MLAS_LSX_INTRINSICS) return __lsx_vftint_w_s(Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) @@ -1466,6 +1732,8 @@ MlasCastToFloat32x4(MLAS_INT32X4 Vector) return _mm_cvtepi32_ps(Vector); #elif defined(MLAS_VSX_INTRINSICS) return vec_ctf(Vector, 0); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_float(Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_convert_i32x4(Vector); #elif defined(MLAS_LSX_INTRINSICS) @@ -1485,7 +1753,7 @@ MlasBroadcastInt32x4(int32_t Value) return _mm_set1_epi32(Value); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_splats(Value); #elif defined(MLAS_LSX_INTRINSICS) return __lsx_vreplgr2vr_w(Value); @@ -1504,6 +1772,8 @@ MlasLoadInt32x4(const int32_t* Buffer) return _mm_loadu_si128((const __m128i*)Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_xl(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); #elif defined(MLAS_LSX_INTRINSICS) @@ -1523,6 +1793,8 @@ MlasStoreInt32x4(int32_t* Buffer, MLAS_INT32X4 Vector) _mm_storeu_si128((__m128i*)Buffer, Vector); #elif 
defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + vec_xst(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); #elif defined(MLAS_LSX_INTRINSICS) @@ -1629,7 +1901,7 @@ MlasXorInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_xor_si128(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_xor(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_xor(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) return __lsx_vxor_v(Vector1, Vector2); @@ -1675,6 +1947,8 @@ MlasMaximumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return MlasBlendInt32x4(Vector2, Vector1, _mm_cmpgt_epi32(Vector1, Vector2)); #elif defined(MLAS_VSX_INTRINSICS) return vec_vmaxsw(Vector1, Vector2); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_max(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_max(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) @@ -1696,6 +1970,8 @@ MlasMinimumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return MlasBlendInt32x4(Vector2, Vector1, _mm_cmpgt_epi32(Vector2, Vector1)); #elif defined(MLAS_VSX_INTRINSICS) return vec_vminsw(Vector1, Vector2); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_min(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_min(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) @@ -1730,7 +2006,7 @@ MlasBroadcastFloat32x4(float Value) return _mm_set1_ps(Value); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) // Suppress wrong GCC warnings MLAS_UNREFERENCED_PARAMETER(Value); return vec_splats(Value); @@ -1751,7 +2027,7 @@ MlasBroadcastFloat32x4(const float* Value) return _mm_load_ps1(Value); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load32_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_splats(*Value); #elif defined(MLAS_LSX_INTRINSICS) return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; @@ -1787,6 +2063,8 @@ MlasLoadFloat32x4(const float* Buffer) return _mm_loadu_ps(Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_xl(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); #elif defined(MLAS_LSX_INTRINSICS) @@ -1807,6 +2085,8 @@ MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) _mm_storeu_ps(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + vec_xst(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); #elif defined(MLAS_LSX_INTRINSICS) @@ -1829,6 +2109,8 @@ MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) MLAS_UNREFERENCED_PARAMETER(Buffer); MLAS_UNREFERENCED_PARAMETER(Vector); vec_st(Vector, 0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + vec_xst(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); #elif defined(MLAS_LSX_INTRINSICS) @@ -1963,7 +2245,7 @@ MlasInterleaveLowFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return zipped.val[0]; #elif defined(MLAS_SSE2_INTRINSICS) return _mm_unpacklo_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) 
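Kernel code never touches these intrinsics directly; it is written once against the wrapper layer that the Z vector (MLAS_ZVECTOR_INTRINSICS) branches above extend. A small usage sketch, assuming mlasi.h is included so the MLAS_FLOAT32X4 wrappers are available:

void AddScalarToBuffer(float* Buffer, size_t N, float Value) {
    const MLAS_FLOAT32X4 Addend = MlasBroadcastFloat32x4(Value);
    size_t i = 0;
    for (; i + 4 <= N; i += 4) {
        MLAS_FLOAT32X4 v = MlasLoadFloat32x4(Buffer + i);
        MlasStoreFloat32x4(Buffer + i, MlasAddFloat32x4(v, Addend));
    }
    for (; i < N; ++i) {
        Buffer[i] += Value;   // scalar tail
    }
}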
+#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_mergeh(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) return (MLAS_FLOAT32X4)__lsx_vilvl_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); @@ -1983,7 +2265,7 @@ MlasInterleaveHighFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return zipped.val[1]; #elif defined(MLAS_SSE2_INTRINSICS) return _mm_unpackhi_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_mergel(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) return (MLAS_FLOAT32X4)__lsx_vilvh_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); @@ -2057,13 +2339,20 @@ MLAS_FLOAT32X4 MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FLOAT32X4 Vector3) { #if defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_TARGET_ARM) + // ARMv7 NEON doesn't have vfmaq_f32() return vmlaq_f32(Vector3, Vector1, Vector2); +#else + return vfmaq_f32(Vector3, Vector1, Vector2); +#endif #elif defined(MLAS_FMA3_INTRINSICS) return _mm_fmadd_ps(Vector1, Vector2, Vector3); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_add_ps(_mm_mul_ps(Vector1, Vector2), Vector3); #elif defined(MLAS_VSX_INTRINSICS) return vec_madd(Vector1, Vector2, Vector3); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return __builtin_s390_vfmasb(Vector1, Vector2, Vector3); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_add(wasm_f32x4_mul(Vector1, Vector2), Vector3); #elif defined(MLAS_LSX_INTRINSICS) @@ -2120,7 +2409,7 @@ MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_cmpgt_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_gt(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return MLAS_FLOAT32X4(vec_cmpgt(Vector1, Vector2)); #elif defined(MLAS_LSX_INTRINSICS) return (MLAS_FLOAT32X4)__lsx_vfcmp_clt_s(Vector2, Vector1); @@ -2204,9 +2493,11 @@ MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vmaxq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_max_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) // Don't use vec_max to avoid undefined behavior if NAN return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2)); +#elif defined(MLAS_WASM_RELAXED_SIMD_INTRINSICS) + return wasm_f32x4_relaxed_max(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_max(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) @@ -2224,9 +2515,11 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vminq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_min_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) // Don't use vec_min to avoid undefined behavior if NAN return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1)); +#elif defined(MLAS_WASM_RELAXED_SIMD_INTRINSICS) + return wasm_f32x4_relaxed_min(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_min(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) @@ -2263,7 +2556,7 @@ MlasReduceAddFloat32x4(MLAS_FLOAT32X4 Vector) VectorLow = vpadd_f32(VectorLow, VectorHigh); VectorLow = vpadd_f32(VectorLow, VectorHigh); return vget_lane_f32(VectorLow, 0); -#elif 
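The VSX and Z vector maximum/minimum paths above deliberately build the result from a compare plus select instead of vec_max/vec_min; a scalar sketch of that select semantics shows why: when either input is NaN the comparison is false and the second operand comes back, matching the second-operand convention the SSE path relies on.

// NaN compares false against everything, so a NaN in either slot yields `b`.
float SelectBasedMax(float a, float b) {
    return (a > b) ? a : b;
}

float SelectBasedMin(float a, float b) {
    return (b > a) ? a : b;
}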
defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) Vector = MlasAddFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); Vector = MlasAddFloat32x4(Vector, vec_splat(Vector, 1)); return Vector[0]; @@ -2286,7 +2579,7 @@ MlasReduceMaximumFloat32x4(MLAS_FLOAT32X4 Vector) VectorLow = vpmax_f32(VectorLow, VectorHigh); VectorLow = vpmax_f32(VectorLow, VectorHigh); return vget_lane_f32(VectorLow, 0); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) Vector = MlasMaximumFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); Vector = MlasMaximumFloat32x4(Vector, vec_splat(Vector, 1)); return Vector[0]; @@ -2309,7 +2602,7 @@ MlasReduceMinimumFloat32x4(MLAS_FLOAT32X4 Vector) VectorLow = vpmin_f32(VectorLow, VectorHigh); VectorLow = vpmin_f32(VectorLow, VectorHigh); return vget_lane_f32(VectorLow, 0); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) Vector = MlasMinimumFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); Vector = MlasMinimumFloat32x4(Vector, vec_splat(Vector, 1)); return Vector[0]; @@ -2335,7 +2628,7 @@ MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector) #if defined(MLAS_SSE2_INTRINSICS) typedef __m128d MLAS_FLOAT64X2; -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) typedef __vector double MLAS_FLOAT64X2; #elif defined(MLAS_LSX_INTRINSICS) typedef __m128d MLAS_FLOAT64X2; @@ -2345,7 +2638,7 @@ typedef __m128d MLAS_FLOAT64X2; #ifndef MLAS_FLOAT64X2_UNSUPPORTED -#if defined(MLAS_VSX_INTRINSICS) +#if defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) template MLAS_FORCEINLINE double @@ -2394,7 +2687,7 @@ MlasBroadcastFloat64x2(double Value) { #if defined(MLAS_SSE2_INTRINSICS) return _mm_set1_pd(Value); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return MLAS_FLOAT64X2{Value, Value}; #elif defined(MLAS_LSX_INTRINSICS) return MLAS_FLOAT64X2{Value, Value}; @@ -2407,7 +2700,7 @@ MlasZeroFloat64x2(void) { #if defined(MLAS_SSE2_INTRINSICS) return _mm_setzero_pd(); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return MlasBroadcastFloat64x2(0.0f); #elif defined(MLAS_LSX_INTRINSICS) return MlasBroadcastFloat64x2(0.0f); @@ -2422,6 +2715,8 @@ MlasLoadFloat64x2(const double* Buffer) return _mm_loadu_pd(Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_xl(0, Buffer); #elif defined(MLAS_LSX_INTRINSICS) return MLAS_FLOAT64X2(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); #endif @@ -2435,6 +2730,8 @@ MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) _mm_storeu_pd(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + vec_xst(Vector, 0, Buffer); #elif defined(MLAS_LSX_INTRINSICS) (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); #endif @@ -2446,7 +2743,7 @@ MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) { #if defined(MLAS_SSE2_INTRINSICS) _mm_store_pd(Buffer, Vector); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) *((MLAS_FLOAT64X2*)Buffer) = Vector; #elif defined(MLAS_LSX_INTRINSICS) (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); @@ -2459,7 +2756,7 @@ 
MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2) { #if defined(MLAS_SSE2_INTRINSICS) return _mm_mul_pd(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return Vector1 * Vector2; #elif defined(MLAS_LSX_INTRINSICS) return __lsx_vfmul_d(Vector1, Vector2); diff --git a/src/lib/platform.cpp b/src/lib/platform.cpp index ed56d82..796bfd1 100644 --- a/src/lib/platform.cpp +++ b/src/lib/platform.cpp @@ -16,6 +16,12 @@ Module Name: --*/ #include "mlasi.h" +#ifdef MLAS_USE_SVE +#include "sve/mlasi_sve.h" +#endif +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +#include "kleidiai/mlasi_kleidiai.h" +#endif #include #include @@ -31,6 +37,11 @@ Module Name: #endif #endif + +#if defined(MLAS_TARGET_S390X) +#include +#endif + #if defined(MLAS_TARGET_ARM64) #if defined(_WIN32) @@ -168,7 +179,6 @@ MlasReadExtendedControlRegister( #if defined(__linux__) #include -#include #endif bool @@ -286,8 +296,14 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; + this->DequantizeLinearS8Kernel = MlasDequantizeLinearS8Kernel; + this->DequantizeLinearU8Kernel = MlasDequantizeLinearU8Kernel; #ifndef __APPLE__ +#ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; +#else // FORCE_GENERIC_ALGORITHMS + this->CastF16ToF32Kernel = nullptr; +#endif // FORCE_GENERIC_ALGORITHMS #endif // __APPLE__ this->NchwcBlockSize = 8; @@ -309,8 +325,11 @@ Return Value: // // Check if the processor supports SSE 4.1 instructions. // - +#ifndef FORCE_GENERIC_ALGORITHMS if ((Cpuid1[2] & 0x80000) != 0) { +#else // FORCE_GENERIC_ALGORITHMS + if (false) { +#endif // FORCE_GENERIC_ALGORITHMS this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchSse41; } @@ -320,7 +339,11 @@ Return Value: // Check if the processor supports the AVX and OSXSAVE features. // +#ifndef FORCE_GENERIC_ALGORITHMS if ((Cpuid1[2] & 0x18000000) == 0x18000000) { +#else // FORCE_GENERIC_ALGORITHMS + if (false) { +#endif // FORCE_GENERIC_ALGORITHMS // // Check if the operating system supports saving SSE and AVX states. 
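The feature tests above read raw CPUID bits: 0x80000 is bit 19 of leaf-1 ECX (SSE4.1) and 0x18000000 is bits 27 and 28 together (OSXSAVE plus AVX). A standalone sketch of the same checks, assuming a GCC or Clang x86 build where <cpuid.h> is available:

#include <cpuid.h>

bool HasSse41() {
    unsigned int eax, ebx, ecx, edx;
    if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) return false;
    return (ecx & (1u << 19)) != 0;        // same bit as the 0x80000 mask above
}

bool HasAvxAndOsxsave() {
    unsigned int eax, ebx, ecx, edx;
    if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) return false;
    return (ecx & 0x18000000u) == 0x18000000u;  // OSXSAVE (bit 27) and AVX (bit 28)
}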
@@ -364,6 +387,8 @@ Return Value: if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) { + this->Avx2Supported_ = true; + this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2; @@ -388,9 +413,10 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + this->RopeDispatch = &MlasRopeDispatchAvx2; // @@ -418,7 +444,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) @@ -454,12 +480,14 @@ Return Value: if ((Cpuid7[1] & 0xC0020000) == 0xC0020000) { + this->Avx512Supported_ = true; + this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Core; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Core; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core; this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; // // Check if the processor supports AVX512VNNI. @@ -472,7 +500,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } @@ -532,27 +560,36 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->RopeDispatch = &MlasRopeDispatchNeon; + this->HGemmDispatch = &MlasHGemmDispatchNeon; + this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon; + this->EltwiseDispatch = &MlasEltwiseDispatchNeon; + +#if defined(MLAS_USE_ARM_NEON_NCHWC) + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon; + this->NchwcBlockSize = MLAS_NEON_NCHWC_BLOCK_SIZE; +#endif // // Check if the processor supports ASIMD dot product instructions. // - bool HasDotProductInstructions; - -#if defined(_WIN32) - HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#else - // Use the cpuinfo value which is read from sysctl and has some additional special cases. - // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 + // Note: // Do NOT use ID_AA64ISAR0_EL1. 
It causes illegal instruction errors on Mac M1 and ARMv8-A chips // as well as failing on other ARM chips as it is an EL1 level register that requires extra // privileges to read. // // uint64_t isar0_el1; // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); - // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; - HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); -#endif + // const bool HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; + + const bool HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); if (HasDotProductInstructions) { this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUdot; @@ -561,21 +598,53 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; + } - // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; + this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; + this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; + this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; + this->MlasConvOverride = ArmKleidiAI::MlasConv; } +#endif + +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + this->ErfKernelRoutine = MlasSveErfKernel; + this->LogisticKernelRoutine = MlasSveLogisticKernel; + this->ReduceMaximumF32Kernel = MlasSveReduceMaximumF32Kernel; + this->ComputeSumExpF32Kernel = MlasSveComputeSumExpF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasSveComputeLogSoftmaxOutputF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasSveComputeSoftmaxOutputF32Kernel; + } + else{ + this->ErfKernelRoutine = MlasErfKernel; + this->LogisticKernelRoutine = MlasLogisticKernel; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + } +#endif -#if defined(__linux__) // // Check if the processor supports ASIMD I8MM instructions. // - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + + const bool HasI8MMInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM(); + if (HasI8MMInstructions) { +#if defined(__linux__) + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; - } #endif + } + + this->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? 
false : true; + this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions); #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; @@ -599,6 +668,11 @@ Return Value: bool HasP9Instructions = hwcap2 & PPC_FEATURE2_ARCH_3_00; #elif defined(_AIX) bool HasP9Instructions = __power_9_andup(); +#elif defined(__FreeBSD__) + unsigned long hwcap2; + elf_aux_info(AT_HWCAP2, &hwcap2, sizeof(hwcap2)); + + bool HasP9Instructions = hwcap2 & PPC_FEATURE2_ARCH_3_00; #endif // __linux__ if (HasP9Instructions) { this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelVSX; @@ -608,7 +682,7 @@ Return Value: #if defined(POWER10) #if (defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__== 10 && __GNUC_MINOR__ >= 2))) || \ (defined(__clang__) && (__clang_major__ >= 12)) -#if defined(__linux__) +#if defined(__linux__) || defined(__FreeBSD__) bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); #elif defined(_AIX) bool HasP10Instructions = (__power_10_andup() && __power_mma_version() == MMA_V31); @@ -623,6 +697,26 @@ Return Value: #endif // MLAS_TARGET_POWER +#if defined(MLAS_TARGET_S390X) + this->GemmFloatKernel = MlasSgemmKernel; + this->GemmDoubleKernel = MlasDgemmKernel; + this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; + this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; + this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; + this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; + this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; + this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; + + bool HasVXEInstructions = getauxval(AT_HWCAP) & HWCAP_S390_VXE; + if (HasVXEInstructions) { + this->GemmFloatKernel = MlasSgemmKernelZVECTOR; + this->GemmU8X8Dispatch = &MlasGemm8X8DispatchZVECTOR; + + this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelZVECTOR; + this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8KernelZVECTOR; + } +#endif // MLAS_TARGET_S390X + #if defined(MLAS_TARGET_LARCH64) // @@ -648,6 +742,9 @@ Return Value: this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelLasx; this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Lasx; + // add new sqn-lasx kernel + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchLasx; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; }else if( cap_lsx ){ diff --git a/src/lib/pooling.cpp b/src/lib/pooling.cpp index 50dcf19..6bb23df 100644 --- a/src/lib/pooling.cpp +++ b/src/lib/pooling.cpp @@ -1533,7 +1533,7 @@ Return Value: c -= 8; } -#elif defined(MLAS_TARGET_POWER) +#elif defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) while (c >= 32) { auto MaximumVector0 = vec_splats(std::numeric_limits::lowest()); diff --git a/src/lib/pooling_fp16.cpp b/src/lib/pooling_fp16.cpp index 98e8473..9765192 100644 --- a/src/lib/pooling_fp16.cpp +++ b/src/lib/pooling_fp16.cpp @@ -80,23 +80,23 @@ PoolInit16x4() } template<> -MLAS_FORCEINLINE +MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element) { - return MlasMaximumFloat16x8(agg, element); + return MlasMaximumFloat16(agg, element); } template<> -MLAS_FORCEINLINE +MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element) { - return MlasMaximumFloat16x4(agg, element); + return MlasMaximumFloat16(agg, element); } template<> -MLAS_FORCEINLINE +MLAS_FORCEINLINE MLAS_FLOAT16X8 
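The s390x branch above detects vector enhancements through the kernel's auxiliary vector; a minimal sketch of that check, assuming an s390x Linux target where HWCAP_S390_VXE is defined (other architectures read different AT_HWCAP/AT_HWCAP2 bits but follow the same shape):

#if defined(__linux__) && defined(__s390x__)
#include <sys/auxv.h>

bool HasVectorEnhancements() {
    // Non-zero when the kernel reports the vector-enhancements facility.
    return (getauxval(AT_HWCAP) & HWCAP_S390_VXE) != 0;
}
#endif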
PoolSummary16x8(MLAS_FLOAT16X8 agg, size_t size) { @@ -105,7 +105,7 @@ PoolSummary16x8(MLAS_FLOAT16X8 agg, size_t size) } template<> -MLAS_FORCEINLINE +MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolSummary16x4(MLAS_FLOAT16X4 agg, size_t size) { @@ -144,28 +144,28 @@ template <> MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element) { - return MlasAddFloat16x8(agg, element); + return MlasAddFloat16(agg, element); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element) { - return MlasAddFloat16x4(agg, element); + return MlasAddFloat16(agg, element); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolSummary16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 context) { - return MlasDivFloat16x8(agg, context); + return MlasDivideFloat16(agg, context); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolSummary16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X8 context) { - return MlasDivFloat16x4(agg, MlasToLowHalfFloat16x4(context)); + return MlasDivideFloat16(agg, MlasToLowHalfFloat16x4(context)); } diff --git a/src/lib/power/qgemm_kernel_power10.cpp b/src/lib/power/qgemm_kernel_power10.cpp index 0f3bc1d..b00e37b 100644 --- a/src/lib/power/qgemm_kernel_power10.cpp +++ b/src/lib/power/qgemm_kernel_power10.cpp @@ -437,49 +437,51 @@ MlasGemmQuantCopyPackA8x8( Vtype a2 = vmask; Vtype a3 = vmask; Vtype a1 = *reinterpret_cast(&a[0]); - if (CountM == 3) { - a3 = *reinterpret_cast(&a[lda * 2]); - } - if (CountM >= 2) { + if (CountM == 1) { + vec_t va1 = AIsSigned ? reinterpret_cast(a1) : reinterpret_cast(vec_sub(a1, vmask)); + *reinterpret_cast(&D[0]) = (vec_t)va1; + vsum = vec_sum4s(va1, vsum); + } else { a2 = *reinterpret_cast(&a[lda]); + if (CountM == 3) { + a3 = *reinterpret_cast(&a[lda * 2]); + } + Vtype vx = + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); + Vtype vx1 = + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); + Vtype vx2 = + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); + Vtype vx3 = + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); + Vtype vx4 = vec_xxpermdi(vx, vx1, 0); + Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx6 = vec_xxpermdi(vx, vx1, 3); + Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); + vec_t vx0 = AIsSigned ? reinterpret_cast(vx4) : reinterpret_cast(vec_sub(vx4, vmask)); + *reinterpret_cast(&D[0]) = vx0; + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx5) : reinterpret_cast(vec_sub(vx5, vmask)); + *reinterpret_cast(&D[16]) = vx0; + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx6) : reinterpret_cast(vec_sub(vx6, vmask)); + *reinterpret_cast(&D[32]) = vx0; + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? 
reinterpret_cast(vx7) : reinterpret_cast(vec_sub(vx7, vmask)); + *reinterpret_cast(&D[48]) = vx0; + vsum = vec_sum4s(vx0, vsum); + } + if (CountM == 1) { + D += 16; + } else { + D += 16 * 4; } - Vtype vx = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx1 = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx2 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx3 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi(vx, vx1, 0); - Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi(vx, vx1, 3); - Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); - vec_t vx0 = - AIsSigned ? reinterpret_cast(vx4) : - reinterpret_cast(vec_sub(vx4, vmask)); - *reinterpret_cast(&D[0]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx5) : - reinterpret_cast(vec_sub(vx5, vmask)); - *reinterpret_cast(&D[16]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx6) : - reinterpret_cast(vec_sub(vx6, vmask)); - *reinterpret_cast(&D[32]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx7) : - reinterpret_cast(vec_sub(vx7, vmask)); - *reinterpret_cast(&D[48]) = vx0; - vsum = vec_sum4s(vx0, vsum); - D += 16 * 4; a += 16; y -= 16; } + if (CountM == 1) { + vsum[0] += (vsum[1] + vsum[2] + vsum[3]); + } while (y >= 4) { Vtype vb = vmask; @@ -496,7 +498,11 @@ MlasGemmQuantCopyPackA8x8( reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); *reinterpret_cast(&D[0]) = vx; vsum = vec_sum4s(vx, vsum); - D += 16; + if (CountM == 1) { + D += 4; + } + else + D += 16; a += 4; y -= 4; } @@ -1059,6 +1065,186 @@ MlasQgemmComputeMMA( } } }; + +MLAS_FORCEINLINE +void +MlasGemmQuantKernel_M1( + const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType *A, + const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType *B, + int32_t *C, + size_t PackedCountK, + size_t CountN, + size_t ldc, + const int32_t *RowSumBuffer, + const int32_t *ColumnSumBuffer, + const int32_t *ZeroPointB, + bool ZeroMode +) +{ + size_t Mval = 1; + while (CountN > 0) { + const int8_t *a = A; + typedef __vector unsigned char vec_t; + typedef __vector signed char svec_t; + const uint8_t *b = B; + MLAS_INT32X4 result = {0}; + __vector signed int VecC = {0, 0, 0, 0}; + __vector signed int VecC2 = {0, 0, 0, 0}; + __vector signed int VecC3 = {0, 0, 0, 0}; + __vector signed int VecC4 = {0, 0, 0, 0}; + size_t k = PackedCountK * MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK; + size_t k1 = PackedCountK; + __vector unsigned char va[4]; + __vector unsigned char pat = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + __vector unsigned char pat2 = {4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7}; + __vector unsigned char pat3 = {8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11}; + __vector unsigned char pat4 = {12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + while (k >= 16) { + const vec_t *vecA = reinterpret_cast(a); + const vec_t *vb = reinterpret_cast(b); + va[0] = vec_perm(vecA[0], vecA[0], pat); + va[1] = vec_perm(vecA[0], vecA[0], pat2); + va[2] = vec_perm(vecA[0], vecA[0], pat3); + va[3] = vec_perm(vecA[0], vecA[0], pat4); + VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC); + VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC); + VecC = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC); + VecC = vec_msum((svec_t)va[3], 
(vec_t)vb[3], VecC); + vb = reinterpret_cast(&b[k1 * 16]); + VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2); + VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2); + VecC2 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC2); + VecC2 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC2); + vb = reinterpret_cast(&b[k1 * 32]); + VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3); + VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3); + VecC3 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC3); + VecC3 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC3); + vb = reinterpret_cast(&b[k1 * 48]); + VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4); + VecC4 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4); + VecC4 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC4); + VecC4 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC4); + b += 64; + a += 16; + k -= 16; + } + if (k >= 12) { + const vec_t *vecA = reinterpret_cast(a); + const vec_t *vb = reinterpret_cast(b); + va[0] = vec_perm(vecA[0], vecA[0], pat); + va[1] = vec_perm(vecA[0], vecA[0], pat2); + va[2] = vec_perm(vecA[0], vecA[0], pat3); + VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC); + VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC); + VecC = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC); + vb = reinterpret_cast(&b[k1 * 16]); + VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2); + VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2); + VecC2 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC2); + vb = reinterpret_cast(&b[k1 * 32]); + VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3); + VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3); + VecC3 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC3); + vb = reinterpret_cast(&b[k1 * 48]); + VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4); + VecC4 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4); + VecC4 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC4); + a += 12; + b += 48; + k -= 12; + } + if (k >= 8) { + const vec_t *vecA = reinterpret_cast(a); + const vec_t *vb = reinterpret_cast(b); + va[0] = vec_perm(vecA[0], vecA[0], pat); + va[1] = vec_perm(vecA[0], vecA[0], pat2); + VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC); + VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC); + vb = reinterpret_cast(&b[k1 * 16]); + VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2); + VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2); + vb = reinterpret_cast(&b[k1 * 32]); + VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3); + VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3); + vb = reinterpret_cast(&b[k1 * 48]); + VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4); + VecC4 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4); + a += 8; + b += 32; + k -= 8; + } + if (k >= 4) { + const vec_t *vecA = reinterpret_cast(a); + const vec_t *vb = reinterpret_cast(b); + va[0] = vec_perm(vecA[0], vecA[0], pat); + VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC); + vb = reinterpret_cast(&b[k1 * 16]); + VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2); + vb = reinterpret_cast(&b[k1 * 32]); + VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3); + vb = reinterpret_cast(&b[k1 * 48]); + VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4); + a += 4; + b += 16; + k -= 4; + } + if (CountN >= 16) { + MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorMMA<8>(&VecC3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); + 
MlasQgemmStoreVectorMMA<12>(&VecC4, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12); + INC_BUFFER(16); + CountN -= 16; + B += 16 * 4 * PackedCountK; + C += 16; + } else { + if (CountN >= 12) { + MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorMMA<8>(&VecC3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); + INC_BUFFER(12); + if (CountN - 12 > 0) + result = VecC4; + CountN -= 12; + C += 12; + } else if (CountN >= 8) { + MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + INC_BUFFER(8); + if (CountN - 8 > 0) + result = VecC3; + CountN -= 8; + C += 8; + } else if (CountN >= 4) { + MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + INC_BUFFER(4); + if (CountN - 4 > 0) + result = VecC2; + CountN -= 4; + C += 4; + } else + result = VecC; + CountN &= 3; + + // Output the remaining partial output block. + if (CountN > 0) { + MlasQgemmStoreScalarMMA<0>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + INC_BUFFER(1); + } + if (CountN >= 2) { + MlasQgemmStoreScalarMMA<1>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + INC_BUFFER(1); + } + if (CountN >= 3) { + MlasQgemmStoreScalarMMA<2>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + INC_BUFFER(1); + } + CountN = 0; + } + } +} + template<> size_t MlasGemmQuantKernel( @@ -1075,6 +1261,10 @@ MlasGemmQuantKernel( bool ZeroMode ) { + if (CountM == 1) { + MlasGemmQuantKernel_M1(A, B, C, PackedCountK, CountN, ldc, RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + return 1; + } if (CountM < 8 && CountM >= 4) { CountM = 4; } diff --git a/src/lib/q4_dq.cpp b/src/lib/q4_dq.cpp index df61d3e..c543770 100644 --- a/src/lib/q4_dq.cpp +++ b/src/lib/q4_dq.cpp @@ -20,10 +20,13 @@ Module Name: #include "q4common.h" -template -constexpr size_t BlkQ4BufSize(size_t N, size_t K) { - const size_t KBlocks = MlasDivRoundup(K, T::BlkLen); - return N * KBlocks * T::BlobSize; +template +constexpr +size_t +BlkQ4BufSize(size_t N, size_t K) +{ + const size_t KBlocks = MlasDivRoundup(K, T::BlkLen); + return N * KBlocks * T::BlobSize; } size_t @@ -325,7 +328,7 @@ struct BitsTraits { static constexpr float halfRange = static_cast(kMid - kMin); // number of qbit elements to pack into whole bytes - static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static constexpr int kPackSize = (qbits == 8) ? 1 : ((qbits == 4) ? 2 : ((qbits == 2) ? 4 : 0)); static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); }; @@ -384,12 +387,14 @@ range2scale(float min, float max, ScaleT& scale) /** - * @brief Blockwise quantization methods + * TODO(fajin): use int4/8 for symmetric quantization so the (vq - zp) operation in MatMulNBits can be saved. + * @brief Blockwise quantization methods. Source is row major. Dest, scale and zp are column major. + * Always quantize to unsigned int. * @tparam ElementT source data type, e.g. 
fp32/fp16 * @tparam block_size number of elemenets quantized together * @tparam qbits number of bits in each quantized element - * @tparam Columnwise true: elements in a block come from one single column - * false: elements in a block come from one single row + * @tparam Columnwise true: quantize along src column, pack along src column. + * false: quantize along src row, pack along src column. */ template < typename ElementT, @@ -399,11 +404,18 @@ template < struct BlockwiseQuantizer { // To support other qbits, need to add bit packing code for // storing to dst and zero points - static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(qbits == 2 || qbits == 4 || qbits == 8, "Only 2b, 4b and 8b block quantization is supported!"); using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + static + MLAS_FORCEINLINE + int GetElem(int val, int idx) + { + return (val >> (qbits * idx)) & ((1 << qbits) - 1); + } + static MLAS_FORCEINLINE void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) @@ -437,14 +449,14 @@ struct BlockwiseQuantizer { scale_num_elements = meta_rows * meta_cols; if (zero_point_bytes) { - // this works for qbits == 4 but may need to be updated for other qbits values + // this works for qbits == 2, 4 or 8 but may need to be updated for other qbits values *zero_point_bytes = ((meta_rows * qbits + 7) / 8) * meta_cols; } } /** * @brief Quantized a Matrix shape [rows, columns], resulting quantized - * and packed data are stored in column major (transposed) + * and packed data are stored in column major (transposed). * @param[out] dst pointer to the quantized weights, column major: [columns, rows] * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] * @param[out] zero_points pointer to the zero points, same shape as scale @@ -476,8 +488,10 @@ struct BlockwiseQuantizer { MlasTryBatchParallel( thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { - uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + constexpr int kPackSize = BitsTraits::kPackSize; + uint8_t zp_bytes[kPackSize], vi[kPackSize]; + std::fill_n(zp_bytes, kPackSize, (uint8_t)BitsTraits::kMid); + std::fill_n(vi, kPackSize, 0); const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); @@ -492,7 +506,7 @@ struct BlockwiseQuantizer { const int meta_col = c / QuantBlk::kColumn; // compute scale and zero point - for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + for (int kpack = 0; kpack < kPackSize; kpack++) { // scan a single block to extract range [min, max] float min = std::numeric_limits::max(); @@ -518,40 +532,42 @@ struct BlockwiseQuantizer { } } - // !! 
4b specific code as we need to pack 2 4b numbers into one byte if (zero_points != nullptr) { - const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; - zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + const int32_t meta_idx = meta_col * ((row_blks + kPackSize - 1) / kPackSize) + meta_row / kPackSize; + if constexpr (qbits == 8) { + zero_points[meta_idx] = zp_bytes[0]; + } else if constexpr (qbits == 4) { + zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } else if constexpr (qbits == 2) { + zero_points[meta_idx] = (zp_bytes[0] & 0x3) | (zp_bytes[1] << 2) | (zp_bytes[2] << 4) | (zp_bytes[3] << 6); + } else { + MLAS_THROW_EX(std::runtime_error, "Unsupported qbits"); + } } for (int32_t j = c; j < c_end; ++j) { const int32_t meta_c = j / QuantBlk::kColumn; - for (int32_t i = r; i < r_end; i += 2) { - const int32_t meta_r = i / QuantBlk::kRow; - const float scale = static_cast(scales[meta_c * row_blks + meta_r]); - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - const int8_t zp = zp_bytes[meta_r & 1]; - const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; - - const float v0 = static_cast(src[i * leadingDimension + j]); - const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), - 0.0f, BitsTraits::kMaxFp); - - uint8_t vi1 = (uint8_t)zp; - if (i + 1 < r_end) { - float reciprocal_scale1 = reciprocal_scale; - if constexpr (QuantBlk::kRow == 1) { - const float scale1 = - static_cast(scales[meta_c * row_blks + meta_r + 1]); - reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; - } - const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); - vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, - BitsTraits::kMaxFp); + for (int32_t i = r; i < r_end; i += kPackSize) { + for (int l = 0; l < kPackSize && i + l < r_end; l++) { + const int32_t meta_r = (i + l) / QuantBlk::kRow; + const float scale = static_cast(scales[meta_c * row_blks + meta_r]); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + const int32_t zp = zp_bytes[meta_r % kPackSize]; + + const float v = static_cast(src[(i + l) * leadingDimension + j]); + vi[l] = (uint8_t)std::clamp(roundf(v * reciprocal_scale + zp), + 0.0f, BitsTraits::kMaxFp); } - // !! 4b specific code - dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); + if constexpr (qbits == 8) { + dst[j * q_rows + i / kPackSize] = vi[0]; + } else if constexpr (qbits == 4) { + dst[j * q_rows + i / kPackSize] = (vi[0] & 0xf) | (vi[1] << 4); + } else if constexpr (qbits == 2) { + dst[j * q_rows + i / kPackSize] = (vi[0] & 0x3) | (vi[1] << 2) | (vi[2] << 4) | (vi[3] << 6); + } else { + MLAS_THROW_EX(std::runtime_error, "Unsupported qbits"); + } } } }); @@ -586,6 +602,7 @@ struct BlockwiseQuantizer { int q_rows, q_cols; quantizedShape(rows, columns, q_rows, q_cols); + constexpr int32_t kPackSize = BitsTraits::kPackSize; MlasTryBatchParallel( thread_pool, total_thrd_blks, @@ -602,38 +619,22 @@ struct BlockwiseQuantizer { for (int32_t j = c; j < c_end; ++j) { const int32_t meta_col = j / QuantBlk::kColumn; - // !! 4b specific code - // the whole loop is 4b specific due to sub 8 bit packing - // and unpacking. 
We can potentially make this qbits generic - // by wraping the packing/unpacking code like cutlass::Array - for (int32_t i = r; i < r_end; i += 2) { + for (int32_t i = r; i < r_end; ++i) { const int32_t meta_row = i / QuantBlk::kRow; - - const float scale0 = - static_cast(scales[meta_col * row_blks + meta_row]); - + const float scale = static_cast(scales[meta_col * row_blks + meta_row]); const int zp_pair = - (zero_points == nullptr) - ? 0x88 - : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; - const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); - - const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; - const float v0 = (static_cast(vi0) - zp0) * scale0; - - dst[j * rows + i] = static_cast(v0); - if ((i + 1) < r_end) { - float scale1 = scale0; - int zp1 = zp0; - if constexpr (QuantBlk::kRow == 1) { - scale1 = - static_cast(scales[meta_col * row_blks + meta_row + 1]); - zp1 = (zp_pair >> 4) & 0xf; - } - const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; - const float v1 = (static_cast(vi1) - zp1) * scale1; - dst[j * rows + (i + 1)] = static_cast(v1); - } + zero_points + ? zero_points[meta_col * ((row_blks + kPackSize - 1) / kPackSize) + meta_row / kPackSize] + : 0; + const int vi_pair = weights[j * q_rows + i / kPackSize]; + + const int zp = + zero_points + ? GetElem(zp_pair, meta_row % kPackSize) + : BitsTraits::kMid; + const int vi = GetElem(vi_pair, i % kPackSize); + const float v = (vi - zp) * scale; + dst[j * rows + i] = ElementT(v); } } }); @@ -1413,6 +1414,27 @@ MlasBlockwiseQuantizedShape( } } +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); template void @@ -1436,6 +1458,50 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); + template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + template void MlasBlockwiseQuantizedShape( @@ -1458,9 +1524,31 @@ MlasBlockwiseQuantizedShape( int& q_cols ); + template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +template void MLASCALL MlasBlockwiseQuantizedBufferSizes( - int qbits, int block_size, bool columnwise, int rows, @@ -1475,75 +1563,108 @@ MlasBlockwiseQuantizedBufferSizes( *q_zero_point_size_in_bytes = 0; } - if (qbits == 4) { - switch (block_size) { - case 16: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 32: - if 
(columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 64: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 128: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 256: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; - default: - // Only block size 16, 32, 64, 128, 256 are supported. - break; - } + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
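// Editorial sketch, not part of the patch: with the qbits-generic BlockwiseQuantizer
// above, kPackSize = 8 / qbits quantized values share one byte, and both the packing and
// the GetElem() unpacking reduce to shifts and masks. The hypothetical helpers below only
// restate that logic:
//
//   uint8_t PackByte(const uint8_t* v, int qbits) {     // value 0 goes in the low bits
//       uint8_t out = 0;
//       for (int i = 0; i < 8 / qbits; ++i)
//           out |= (v[i] & ((1u << qbits) - 1)) << (qbits * i);
//       return out;
//   }
//   int UnpackElem(int packed, int idx, int qbits) {    // same expression as GetElem()
//       return (packed >> (qbits * idx)) & ((1 << qbits) - 1);
//   }
//
// For qbits == 2 this packs four values per byte (v0 | v1<<2 | v2<<4 | v3<<6), for
// qbits == 4 two values, and for qbits == 8 the byte is stored unchanged, matching the
// explicit branches in the quantize/dequantize paths above and the zero-point sizing of
// ((meta_rows * qbits + 7) / 8) * meta_cols bytes.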
+ break; } } +template +void MLASCALL +MlasBlockwiseQuantizedBufferSizes<2>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + +template +void MLASCALL +MlasBlockwiseQuantizedBufferSizes<4>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + +template +void MLASCALL +MlasBlockwiseQuantizedBufferSizes<8>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); template void @@ -1617,6 +1738,36 @@ MlasQuantizeBlockwise( } } +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + MLAS_FP16* scales, + uint8_t* zero_points, + const MLAS_FP16* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + template void MlasQuantizeBlockwise( @@ -1647,6 +1798,35 @@ MlasQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); + template + void + MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + template + void + MlasQuantizeBlockwise( + uint8_t* dst, + MLAS_FP16* scales, + uint8_t* zero_points, + const MLAS_FP16* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); template void @@ -1714,6 +1894,32 @@ MlasDequantizeBlockwise( } } +template void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasDequantizeBlockwise( + MLAS_FP16* dst, + const uint8_t* src, + const MLAS_FP16* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + template void MlasDequantizeBlockwise( float* dst, @@ -1727,6 +1933,45 @@ MlasDequantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template void +MlasDequantizeBlockwise( + MLAS_FP16* dst, + const uint8_t* src, + const MLAS_FP16* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasDequantizeBlockwise( + MLAS_FP16* dst, + const uint8_t* src, + const MLAS_FP16* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + template bool MlasQDQQuantizeBlockwise( diff --git a/src/lib/q4common.h b/src/lib/q4common.h index 74f8058..5febbd8 100644 --- a/src/lib/q4common.h +++ b/src/lib/q4common.h @@ -30,7 +30,6 @@ Module Name: } while (false) #endif - #include "mlas_q4.h" #include "mlasi.h" diff --git 
a/src/lib/q4gemm.h b/src/lib/q4gemm.h index d16798e..89528fd 100644 --- a/src/lib/q4gemm.h +++ b/src/lib/q4gemm.h @@ -126,7 +126,7 @@ MlasQ4GemmOperation( size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_S390X) auto RowsHandled = GetMlasPlatform().GemmFloatKernel( a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true); #else diff --git a/src/lib/q4gemm_avx512.cpp b/src/lib/q4gemm_avx512.cpp index fd71777..72d2a62 100644 --- a/src/lib/q4gemm_avx512.cpp +++ b/src/lib/q4gemm_avx512.cpp @@ -18,10 +18,10 @@ Module Name: --*/ #include "q4gemm.h" +#include #include #include -#include struct MLAS_FP_Q4_GEMM_KERNEL_AVX512VNNI { static constexpr size_t StrideM = 256; diff --git a/src/lib/qgemm.cpp b/src/lib/qgemm.cpp index 859fcd0..62a23c8 100644 --- a/src/lib/qgemm.cpp +++ b/src/lib/qgemm.cpp @@ -14,10 +14,16 @@ Module Name: operation (QGEMM). --*/ - +#include #include "mlasi.h" #include "qgemm.h" +// TODO: When overrides are implemented, remove this +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +#include "kleidiai/mlasi_kleidiai.h" +#endif + + // // Define the parameters to execute segments of a QGEMM operation on worker // threads. @@ -144,14 +150,7 @@ MlasGemmBatch( const double Complexity = double(M) * double(N) * double(K) * double(BatchN); - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_QGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); if (TargetThreadCount >= MaximumThreadCount) { @@ -202,6 +201,26 @@ MlasGemmBatch( }); } +void +MLASCALL +MlasDynamicQGemmBatch ( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +) { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback and putting in guards + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(Shape); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); +} int32_t MlasSymmQgemmGetKernelOutputCnt() @@ -300,15 +319,40 @@ MlasSymmQgemmBatch( }); } + + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +) +{ + size_t bytes = 0; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback available + //TODO: Insert Override + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + + return bytes; +} + + size_t MLASCALL MlasGemmPackBSize( size_t N, - size_t K, + size_t K, bool AIsSigned, bool BIsSigned ) @@ -361,10 +405,38 @@ Return Value: const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) 
& ~(BufferAlignment - 1); + //If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for + //this use case. Concat both packed representations for later decision. + return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K); +} - return AlignedBytesRequired; +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +) +{ +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(Scales); + MLAS_UNREFERENCED_PARAMETER(Bias); + MLAS_UNREFERENCED_PARAMETER(PackedB); } + void MLASCALL MlasGemmPackB( @@ -407,7 +479,6 @@ Return Value: // // Retrieve the packing parameters. // - const auto* GemmQuantDispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); size_t PackedK = GemmQuantDispatch->PackedK; @@ -479,7 +550,7 @@ size_t MLASCALL MlasSymmQgemmPackBSize( size_t N, - size_t K, + size_t K, bool AIsSigned ) { @@ -522,7 +593,6 @@ MlasSymmQgemmPackBSize( #pragma warning(pop) #endif - void MLASCALL MlasSymmQgemmPackB( diff --git a/src/lib/qgemm.h b/src/lib/qgemm.h index 1ef5b5f..2730b1a 100644 --- a/src/lib/qgemm.h +++ b/src/lib/qgemm.h @@ -867,7 +867,8 @@ MlasGemmQuantGetDispatch( { const MLAS_GEMM_QUANT_DISPATCH* GemmQuantDispatch = &MlasGemmQuantDispatchDefault; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) +#if !defined(FORCE_GENERIC_ALGORITHMS) +#if defined(MLAS_TARGET_AMD64_IX86) if (AIsSigned) { GemmQuantDispatch = BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch; @@ -885,6 +886,14 @@ MlasGemmQuantGetDispatch( if(BIsSigned || !AIsSigned) { GemmQuantDispatch = &MlasGemmU8X8DispatchNeon; } +#elif defined(MLAS_TARGET_WASM_RELAXED_SIMD) + if (!AIsSigned) { + if (HasUSDot()) { + GemmQuantDispatch = &MlasGemmU8X8DispatchWasmRelaxedSimd; + } else { + GemmQuantDispatch = &MlasGemmU8X8DispatchWasmSimd; + } + } #elif defined(MLAS_TARGET_WASM_SIMD) if (!AIsSigned) { GemmQuantDispatch = &MlasGemmU8X8DispatchWasmSimd; @@ -895,7 +904,17 @@ MlasGemmQuantGetDispatch( if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchPOWER10) { GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch; } +#elif defined(MLAS_TARGET_LARCH64) + if (!AIsSigned) { + GemmQuantDispatch = + BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch; + } +#elif defined(MLAS_TARGET_S390X) + if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchZVECTOR) { + GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch; + } #endif +#endif // !defined(FORCE_GENERIC_ALGORITHMS) if (nullptr == GemmQuantDispatch) { std::stringstream ss; diff --git a/src/lib/qgemm_kernel_amx.cpp b/src/lib/qgemm_kernel_amx.cpp index 5d4f1fd..04f1762 100644 --- a/src/lib/qgemm_kernel_amx.cpp +++ b/src/lib/qgemm_kernel_amx.cpp @@ -19,6 +19,7 @@ Module Name: #include "amx_common.h" #include + #define TMM0 0 #define TMM1 1 #define TMM2 2 diff --git a/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp b/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp new file mode 100644 index 0000000..56c67aa --- /dev/null +++ b/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp @@ -0,0 +1,598 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. 
+ +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_wasmrelaxedsimd.cpp + +Abstract: + + This module implements QGEMM kernel for WebAssembly Relaxed SIMD128. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +bool HasUSDot() { +// Check out-of-bounds behavior of Relaxed Integer Dot Product with Accumulation with signed and unsigned input (e.g. vpdpbusd). + const v128_t int8_input = wasm_i8x16_const(0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0); + const volatile v128_t xint8_input = wasm_i8x16_const(0, 0, 0, -128, 0, 0, -128, 0, 0, -128, 0, 0, -128, 0, 0, 0); // volatile to confuse Clang which otherwise ICE's + const v128_t xint8_output = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(int8_input, xint8_input, wasm_i8x16_const_splat(0)); + + const volatile v128_t overflow_input = wasm_i8x16_const(-128, -128, -128, -128, -128, -128, -1, -1, -1, -1, -128, -128, -1, -1, -1, -1); // volatile to confuse Clang which otherwise ICE's + const v128_t overflow_output = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(wasm_i8x16_const_splat(-128), overflow_input, wasm_i8x16_const_splat(0)); + return !wasm_v128_any_true(wasm_v128_or( + wasm_v128_xor(xint8_output, wasm_i32x4_const_splat(128)), + wasm_v128_xor(overflow_output, wasm_i32x4_const(-65536, -98048, -98048, -130560)))); +} + +// wasm implementation of "_mm_unpacklo_epi8" +v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i8x16_unpacklo_relaxed(v128_t a, v128_t b) { + return wasm_i8x16_shuffle(a, b, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23); +} + +// wasm implementation of "_mm_unpacklo_epi16" +v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i16x8_unpacklo_relaxed(v128_t a, v128_t b) { + return wasm_i8x16_shuffle(a, b, 0, 1, 16, 17, 2, 3, 18, 19, 4, 5, 20, 21, 6, 7, 22, 23); +} + +// wasm implementation of "_mm_unpackhi_epi16" +v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i16x8_unpackhi_relaxed(v128_t a, v128_t b) { + return wasm_i8x16_shuffle(a, b, 8, 9, 24, 25, 10, 11, 26, 27, 12, 13, 28, 29, 14, 15, 30, 31); +} + +struct MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD +{ + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef uint8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 4; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 12, 128, 128 }; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{0, 0, 0}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::Strides; + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmQuantFixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (!BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + const v128_t ZeroVector = wasm_i64x2_const(0, 0); + const v128_t OnesByteBroadcast = wasm_i8x16_splat(1); + uint8_t PaddedMatrixAData[8] = { 0 }; + + // + // Process a single row of matrix A in a loop. + // + + while (CountM > 0) { + + const uint8_t* a = A; + size_t k = CountK; + v128_t ReductionVector = ZeroVector; + + // + // Copy the source bytes to the packed buffer. 
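// Editorial note, not part of the patch: the WebAssembly relaxed dot-product leaves the
// handling of second-operand lanes with the high bit set implementation-defined, so
// HasUSDot() above probes it at runtime with values outside the 7-bit range and accepts
// the relaxed path only if every lane comes back as if the second operand were treated as
// unsigned 8-bit data (a vpdpbusd-style signed x unsigned dot, each 32-bit lane
// accumulating four byte products). MlasGemmQuantGetDispatch() in qgemm.h uses the same
// check to fall back to MlasGemmU8X8DispatchWasmSimd when the probe fails.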
+ // + // The packed buffer has the same data ordering as the source bytes, + // but CountK is aligned up to a multiple of 4 to maintain 32-bit + // alignment. All extra bytes are zero-padded. + // + // Accumulate into an intermediate per-row accumulator. + + while (k >= 16) { + + v128_t Bytes = wasm_v128_load(&a[0]); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); + + wasm_v128_store(&D[0], Bytes); + + a += 16; + D += 16; + k -= 16; + } + + if (k >= 8) { + v128_t Bytes = wasm_v128_load64_zero(&a[0]); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); + wasm_v128_store64_lane(&D[0], Bytes, 0); + + a += 8; + D += 8; + k -= 8; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* padded = PaddedMatrixAData; + uint8_t* padded_end = padded + k; + + do { + padded[0] = a[0]; + padded++; + a++; + } while (padded < padded_end); + + v128_t Bytes = wasm_v128_load64_zero(PaddedMatrixAData); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); + + // + // Copy quads of 8-bit values from the vector to the packed + // buffer and rotate the vector for the next iteration. + // + + for (size_t quads = (k + 3) / 4; quads > 0; quads--) { + *((int32_t*)D) = wasm_i32x4_extract_lane(Bytes, 0); + D += 4; + Bytes = wasm_i32x4_shuffle(Bytes, wasm_i32x4_splat(0), 1, 2, 3, 0); + } + } + + // + // Reduce the partial accumulators. + // + + ReductionVector = wasm_i32x4_add(ReductionVector, + wasm_i32x4_shuffle(ReductionVector, wasm_i32x4_splat(0), 2, 3, 2, 3)); + ReductionVector = wasm_i32x4_add(ReductionVector, + wasm_i32x4_shuffle(ReductionVector, wasm_i32x4_splat(0), 1, 0, 1, 0)); + + *RowSumBuffer++ = wasm_i32x4_extract_lane(ReductionVector, 0); + + A += lda; + CountM -= 1; + } +} + + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd( + MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* D, + v128_t BytesRow0, + v128_t BytesRow1, + v128_t BytesRow2, + v128_t BytesRow3, + v128_t BitFlipVector, + v128_t OnesByteBroadcast, + v128_t ColumnSums[2] +) +{ + v128_t PairsInterleaved0 = wasm_i8x16_unpacklo_relaxed(BytesRow0, BytesRow1); + v128_t PairsInterleaved1 = wasm_i8x16_unpacklo_relaxed(BytesRow2, BytesRow3); + + PairsInterleaved0 = wasm_v128_xor(PairsInterleaved0, BitFlipVector); + PairsInterleaved1 = wasm_v128_xor(PairsInterleaved1, BitFlipVector); + + v128_t QuadsInterleaved0 = wasm_i16x8_unpacklo_relaxed(PairsInterleaved0, PairsInterleaved1); + v128_t QuadsInterleaved1 = wasm_i16x8_unpackhi_relaxed(PairsInterleaved0, PairsInterleaved1); + + ColumnSums[0] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(QuadsInterleaved0, OnesByteBroadcast, ColumnSums[0]); + ColumnSums[1] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(QuadsInterleaved1, OnesByteBroadcast, ColumnSums[1]); + + wasm_v128_store(&D[0], QuadsInterleaved0); + wasm_v128_store(&D[16], QuadsInterleaved1); +} + +template<> +void +MlasGemmQuantCopyPackB( + MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + const v128_t OnesByteBroadcast = wasm_i8x16_splat(1); + const v128_t BitFlipVector = wasm_i32x4_splat(BIsSigned ? 0 : 0x80808080); + + // + // Process 8 columns of matrix B in a loop. 
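// Editorial sketch, not part of the patch: the row and column sums gathered while packing
// serve the usual zero-point expansion
//   sum_k (a[k] - za)(b[k] - zb) = sum_k a[k]*b[k] - zb*sum_k a[k] - za*sum_k b[k] + K*za*zb,
// which is why the kernels below seed each accumulator from RowSum * ZeroPointB plus the
// column sums before running the pure 8-bit dot-product loop (the za-dependent terms and
// the exact signs are assumed to be folded in upstream). When B is unsigned, the XOR with
// BitFlipVector together with the ZeroPointB ^ 0x80 fixup above shifts B into the signed
// byte range used by the first dot-product operand.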
+ // + + while (CountN >= 8) { + + const uint8_t* b = B; + size_t k = CountK; + v128_t ColumnSums[2]; + + ColumnSums[0] = wasm_i64x2_const(0, 0); + ColumnSums[1] = wasm_i64x2_const(0, 0); + + // + // Interleave rows of matrix B and write to the packed buffer. + // + + while (k >= MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK) { + + v128_t BytesRow0 = wasm_v128_load64_zero(&b[0]); + v128_t BytesRow1 = wasm_v128_load64_zero(&b[ldb]); + v128_t BytesRow2 = wasm_v128_load64_zero(&b[ldb * 2]); + v128_t BytesRow3 = wasm_v128_load64_zero(&b[ldb * 3]); + + MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd(D, BytesRow0, BytesRow1, BytesRow2, BytesRow3, BitFlipVector, OnesByteBroadcast, ColumnSums); + + b += ldb * 4; + D += 32; + k -= 4; + } + + if (k > 0) { + + v128_t BytesRow0 = wasm_v128_load64_zero(&b[0]); + v128_t BytesRow1 = BitFlipVector; + v128_t BytesRow2 = BitFlipVector; + v128_t BytesRow3 = BitFlipVector; + + if (k >= 2) { + BytesRow1 = wasm_v128_load64_zero(&b[ldb]); + } + + if (k >= 3) { + BytesRow2 = wasm_v128_load64_zero(&b[ldb * 2]); + } + + MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd(D, BytesRow0, BytesRow1, BytesRow2, BytesRow3, BitFlipVector, OnesByteBroadcast, ColumnSums); + + D += 32; + } + + wasm_v128_store(&ColumnSumBuffer[0], ColumnSums[0]); + wasm_v128_store(&ColumnSumBuffer[4], ColumnSums[1]); + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + + const uint8_t* b = B; + size_t k = CountK; + v128_t ColumnSums[2]; + uint8_t PaddedMatrixBData[32]; + + wasm_v128_store(&PaddedMatrixBData[0], BitFlipVector); + wasm_v128_store(&PaddedMatrixBData[16], BitFlipVector); + + ColumnSums[0] = wasm_i64x2_const(0, 0); + ColumnSums[1] = wasm_i64x2_const(0, 0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. 
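// Editorial note, not part of the patch: each iteration of the loop above writes a
// 4-row x 8-column tile of B as 32 contiguous packed bytes -- QuadsInterleaved0 holds
// columns 0..3 and QuadsInterleaved1 columns 4..7, with the four K-values of one column
// stored adjacently -- so the kernel can dot a single 32-bit broadcast from packed A
// against both halves. Roughly:
//   D[0..15]  = { B[k0..k3][col0], B[k0..k3][col1], B[k0..k3][col2], B[k0..k3][col3] }
//   D[16..31] = { B[k0..k3][col4], ..., B[k0..k3][col7] }
// The remainder path below builds the same layout through the zero-padded stack buffer.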
+ // + + while (k >= MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded[16] = bcopy[ldb * 2]; + padded[24] = bcopy[ldb * 3]; + padded++; + bcopy++; + } while (padded < padded_end); + + v128_t BytesRow0 = wasm_v128_load64_zero(&PaddedMatrixBData[0]); + v128_t BytesRow1 = wasm_v128_load64_zero(&PaddedMatrixBData[8]); + v128_t BytesRow2 = wasm_v128_load64_zero(&PaddedMatrixBData[16]); + v128_t BytesRow3 = wasm_v128_load64_zero(&PaddedMatrixBData[24]); + + MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd(D, BytesRow0, BytesRow1, BytesRow2, BytesRow3, BitFlipVector, OnesByteBroadcast, ColumnSums); + + b += ldb * 4; + D += 32; + k -= 4; + } + + if (k > 0) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + wasm_v128_store(&PaddedMatrixBData[0], BitFlipVector); + wasm_v128_store(&PaddedMatrixBData[16], BitFlipVector); + + if (k == 3) { + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded[16] = bcopy[ldb * 2]; + padded++; + bcopy++; + } while (padded < padded_end); + } else if (k == 2) { + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded++; + bcopy++; + } while (padded < padded_end); + } else { + do { + padded[0] = bcopy[0]; + padded++; + bcopy++; + } while (padded < padded_end); + } + + v128_t BytesRow0 = wasm_v128_load64_zero(&PaddedMatrixBData[0]); + v128_t BytesRow1 = wasm_v128_load64_zero(&PaddedMatrixBData[8]); + v128_t BytesRow2 = wasm_v128_load64_zero(&PaddedMatrixBData[16]); + v128_t BytesRow3 = wasm_v128_load64_zero(&PaddedMatrixBData[24]); + + MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd(D, BytesRow0, BytesRow1, BytesRow2, BytesRow3, BitFlipVector, OnesByteBroadcast, ColumnSums); + } + + wasm_v128_store(&ColumnSumBuffer[0], ColumnSums[0]); + wasm_v128_store(&ColumnSumBuffer[4], ColumnSums[1]); + } +} + + +//-------------------------------------------------------------------------- +// Small helper that performs one (A row) × (8 B columns) FMA step. +//-------------------------------------------------------------------------- +MLAS_FORCEINLINE void DotPairAdd(v128_t ABroadcast, + v128_t BVec0, + v128_t BVec1, + v128_t Acc[2]) { + Acc[0] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BVec0, ABroadcast, Acc[0]); + Acc[1] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BVec1, ABroadcast, Acc[1]); +} + +//-------------------------------------------------------------------------- +// Generic RowCount×8 kernel implementation (RowCount is 6 or 1 at compile‑time) +//-------------------------------------------------------------------------- + +template +static size_t GemmQuantKernelNx8Impl( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t /*CountM — ignored*/, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + constexpr size_t ColBlock = 8; + const auto PackedK = MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK; + + // Build row‑wise pointer tables (a[r] & c[r]). 
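// Editorial note, not part of the patch: GemmQuantKernelNx8Impl is written once and
// instantiated for RowCount == 6 and RowCount == 1 (see MlasGemmQuantKernel6x8/1x8
// below). Packed A stores every row as a contiguous run of PackedK * PackedCountK bytes,
// so a[r] is simply the r-th packed row of A and c[r] the r-th output row of C, and the
// 8-column main path and the 4/2/1-column tails are shared by both instantiations.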
+ const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* a[RowCount]; + int32_t* c[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + a[r] = (const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType*)(A + r * PackedK * PackedCountK); + c[r] = (int32_t*)(C + r * ldc); + } + + while (CountN > 0) { + // ------------------------------------------------------------------ + // 1) Initialize accumulators with row & column sums (and zero‑points) + // ------------------------------------------------------------------ + v128_t Acc[RowCount][2]; + + if (ZeroPointB) { + v128_t zp0 = wasm_v128_load(ZeroPointB + 0); + v128_t zp1 = wasm_v128_load(ZeroPointB + 4); + ZeroPointB += 8; + + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_mul(RowSumValues, zp0); + Acc[r][1] = wasm_i32x4_mul(RowSumValues, zp1); + } + } else { + for (size_t r = 0; r < RowCount; ++r) { + Acc[r][0] = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][1] = Acc[r][0]; + } + } + + v128_t col0 = wasm_v128_load(ColumnSumBuffer + 0); // first 4 col sums + v128_t col1 = wasm_v128_load(ColumnSumBuffer + 4); // next 4 col sums + for (size_t r = 0; r < RowCount; ++r) { + Acc[r][0] = wasm_i32x4_add(Acc[r][0], col0); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], col1); + } + ColumnSumBuffer += 8; + + // ------------------------------------------------------------------ + // 2) Broadcast each pair of 8-bit values from the matrix A and multiply + // with the pair of 8-bit values from matrix B, and add the 32-bit + // intermediate into the accumulator registers. + // ------------------------------------------------------------------ + size_t k = PackedCountK; + while (k > 0) { + v128_t ABroadcast[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + ABroadcast[r] = wasm_v128_load32_splat(a[r]); // broadcast 4 × u8 + a[r] += 4; + } + + v128_t B0 = wasm_v128_load(&B[0]); + v128_t B1 = wasm_v128_load(&B[16]); + for (size_t r = 0; r < RowCount; ++r) { + DotPairAdd(ABroadcast[r], B0, B1, Acc[r]); + } + B += 32; + k -= 1; + } + + // ------------------------------------------------------------------ + // 3) Output the accumulator block after optionally accumulating the values + // from matrix C. + // ------------------------------------------------------------------ + + if (CountN >= 8) { + // ---- Full 8‑column tile ---- + for (size_t r = 0; r < RowCount; ++r) { + if (!ZeroMode) { + Acc[r][0] = wasm_i32x4_add(Acc[r][0], wasm_v128_load(c[r] + 0)); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], wasm_v128_load(c[r] + 4)); + } + wasm_v128_store(c[r] + 0, Acc[r][0]); + wasm_v128_store(c[r] + 4, Acc[r][1]); + a[r] -= PackedCountK * 4; // Rewind a[r] for next N‑tile (PackedCountK * 4 bytes each). 
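// (Editorial, not part of the patch) The K loop above advanced each a[r] by 4 bytes per
// iteration, i.e. by PackedCountK * 4 in total, so stepping back by the same amount
// repositions a[r] at the start of its packed row; the same rows of A are then reused
// against the next 8-column tile of B while c[r], ColumnSumBuffer and ZeroPointB keep
// advancing along the N dimension.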
+ c[r] += ColBlock; + } + CountN -= ColBlock; + } else { + // ---- 4/2/1‑column tails ---- + auto Tail = [&](size_t cols, auto load_c, auto store_c) { + for (size_t r = 0; r < RowCount; ++r) { + if (!ZeroMode) Acc[r][0] = wasm_i32x4_add(Acc[r][0], load_c(c[r])); + } + for (size_t r = 0; r < RowCount; ++r) store_c(c[r], Acc[r][0]); + for (size_t r = 0; r < RowCount; ++r) c[r] += cols; + }; + + if (CountN & 4) { + Tail(4, + [](int32_t* p) { return wasm_v128_load(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store(p, v); }); + for (size_t r = 0; r < RowCount; ++r) Acc[r][0] = Acc[r][1]; + } + if (CountN & 2) { + Tail(2, + [](int32_t* p) { return wasm_v128_load64_zero(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store64_lane(p, v, 0); }); + for (size_t r = 0; r < RowCount; ++r) + Acc[r][0] = wasm_i32x4_shuffle(Acc[r][0], wasm_i32x4_splat(0), 2, 3, 2, 3); + } + if (CountN & 1) { + for (size_t r = 0; r < RowCount; ++r) { + int32_t v = wasm_i32x4_extract_lane(Acc[r][0], 0); + if (!ZeroMode) v += *c[r]; + *c[r] = v; + } + } + CountN = 0; + } + } + return RowCount; +} + + +size_t MlasGemmQuantKernel6x8( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<6>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + +size_t MlasGemmQuantKernel1x8( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<1>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + + +template <> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode +) +{ + size_t RowsHandled = 0; + if (CountM >= 6) { + RowsHandled = MlasGemmQuantKernel6x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } else { + RowsHandled = MlasGemmQuantKernel1x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd = { + MlasGemmQuantOperation, + nullptr, + nullptr, + MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK, + 0, + 6 // multiple of kernel stride M +}; diff --git a/src/lib/qgemm_kernel_wasmsimd.cpp b/src/lib/qgemm_kernel_wasmsimd.cpp index 1f33d77..84f6c6b 100644 --- a/src/lib/qgemm_kernel_wasmsimd.cpp +++ b/src/lib/qgemm_kernel_wasmsimd.cpp @@ -322,181 +322,209 @@ MlasGemmQuantCopyPackB( } } -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowWasmSimd( - v128_t ABroadcast, - const int16_t* B, - v128_t Accumulators[2] -) -{ - v128_t BElements0 = wasm_v128_load(&B[0]); - v128_t BElements1 = 
wasm_v128_load(&B[8]); - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_i32x4_dot_i16x8(BElements0, ABroadcast)); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_i32x4_dot_i16x8(BElements1, ABroadcast)); +//------------------------------------------------------------------ +// Helper – dot‑product add for i16×i16 → i32 pairs. +//------------------------------------------------------------------ +MLAS_FORCEINLINE void DotPairAddI16(v128_t ABroadcast, + v128_t BVec0, + v128_t BVec1, + v128_t Acc[2]) { + Acc[0] = wasm_i32x4_add(Acc[0], wasm_i32x4_dot_i16x8(BVec0, ABroadcast)); + Acc[1] = wasm_i32x4_add(Acc[1], wasm_i32x4_dot_i16x8(BVec1, ABroadcast)); } +//------------------------------------------------------------------ +// Generic RowCount×8 kernel (RowCount = 4 or 1) for WASM SIMD. +//------------------------------------------------------------------ -template<> -size_t -MlasGemmQuantKernel( +template +static size_t GemmQuantKernelNx8Impl( const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, int32_t* C, size_t PackedCountK, - size_t CountM, + size_t /*CountM — ignored*/, size_t CountN, size_t ldc, const int32_t* RowSumBuffer, const int32_t* ColumnSumBuffer, const int32_t* ZeroPointB, - bool ZeroMode - ) + bool ZeroMode) { - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - while (CountN > 0) { - - v128_t Accumulators[2]; - // - // Initialize the accumulators with the row and column sums. - // - - int32_t RowSumValue = RowSumBuffer[0]; - - if (ZeroPointB != nullptr) { + constexpr size_t ColBlock = 8; + const auto PackedK = MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK; // ==2 - int32_t ScaledRowSumBuffer[8]; - - for (size_t i = 0; i < 8; i++) { - ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; - } + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* a[RowCount]; + int32_t* c[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + a[r] = (const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType*)(A + r * PackedK * PackedCountK); + c[r] = (int32_t*)(C + r * ldc); + } + while (CountN > 0) { + // ------------------------------------------------------------------ + // 1) Initialize accumulators with row & column sums (and zero‑points) + // ------------------------------------------------------------------ + v128_t Acc[RowCount][2]; + v128_t col0 = wasm_v128_load(ColumnSumBuffer + 0); + v128_t col1 = wasm_v128_load(ColumnSumBuffer + 4); + + if (ZeroPointB) { + v128_t zp0 = wasm_v128_load(ZeroPointB + 0); + v128_t zp1 = wasm_v128_load(ZeroPointB + 4); ZeroPointB += 8; - Accumulators[0] = wasm_v128_load(&ScaledRowSumBuffer[0]); - Accumulators[1] = wasm_v128_load(&ScaledRowSumBuffer[4]); - - } - else { - - Accumulators[0] = wasm_i32x4_splat(RowSumValue); - Accumulators[1] = Accumulators[0]; + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_add(wasm_i32x4_mul(RowSumValues, zp0), col0); + Acc[r][1] = wasm_i32x4_add(wasm_i32x4_mul(RowSumValues, zp1), col1); + } + } else { + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_add(RowSumValues, col0); + Acc[r][1] = wasm_i32x4_add(RowSumValues, col1); + } } - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&ColumnSumBuffer[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&ColumnSumBuffer[4])); ColumnSumBuffer += 8; - // - // Broadcast each pair of 16-bit values 
from the matrix A and multiply + // ---------------------------------------------------------------------- + // 2) Broadcast each pair of 16-bit values from the matrix A and multiply // with the pair of 16-bit values from matrix B, and add the 32-bit // intermediate into the accumulator registers. - // - - const int16_t* a = A; + // ---------------------------------------------------------------------- size_t k = PackedCountK; + while (k > 0) { + v128_t ABroadcast[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + ABroadcast[r] = wasm_v128_load32_splat(a[r]); + a[r] += 2; + } - while (k >= 4) { - - v128_t AElements = wasm_v128_load((v128_t*)a); - v128_t ABroadcast; - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 0, 0, 0, 0); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[0], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 1, 1, 1, 1); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[16], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 2, 2, 2, 2); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[32], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 3, 3, 3, 3); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[48], Accumulators); - - a += 4 * 2; - B += 4 * 16; - k -= 4; - } + v128_t B0 = wasm_v128_load(B + 0); // cols 0‑3 (8 i16) + v128_t B1 = wasm_v128_load(B + 8); // cols 4‑7 (8 i16) - while (k > 0) { - v128_t ABroadcast = wasm_i32x4_splat(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[0], Accumulators); + for (size_t r = 0; r < RowCount; ++r) { + DotPairAddI16(ABroadcast[r], B0, B1, Acc[r]); + } - a += 2; B += 16; k -= 1; } - // - // Output the accumulator block after optionally accumulating the values + // ------------------------------------------------------------------ + // 3) Output the accumulator block after optionally accumulating the values // from matrix C. - // - + // ------------------------------------------------------------------ if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&C[4])); - } - - wasm_v128_store(&C[0], Accumulators[0]); - wasm_v128_store(&C[4], Accumulators[1]); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. - // - - if ((CountN & 4) != 0) { - + for (size_t r = 0; r < RowCount; ++r) { if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); + Acc[r][0] = wasm_i32x4_add(Acc[r][0], wasm_v128_load(c[r] + 0)); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], wasm_v128_load(c[r] + 4)); } - - wasm_v128_store(&C[0], Accumulators[0]); - C += 4; - - Accumulators[0] = Accumulators[1]; + wasm_v128_store(c[r] + 0, Acc[r][0]); + wasm_v128_store(c[r] + 4, Acc[r][1]); + c[r] += ColBlock; + a[r] -= PackedCountK * 2; // Rewind a[r] for next N-tile (PackedCountK * 2 elements, 16-bit each). 
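// Editorial note, not part of the patch: this is the same RowCount-templated structure as
// the relaxed-SIMD kernel above, but packed A here holds 16-bit values with PackedK == 2,
// so one 32-bit broadcast covers a single K-pair and the rewind steps back
// PackedCountK * 2 elements. The dispatcher below runs the 4x8 instantiation whenever
// CountM >= 4 and otherwise the 1x8 one, returning the number of rows actually handled so
// the caller can iterate over the remaining rows.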
} - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load64_zero(&C[0])); + CountN -= 8; + } else { + // ---- 4/2/1‑column tails ---- + auto Tail = [&](size_t cols, auto load_c, auto store_c) { + for (size_t r = 0; r < RowCount; ++r) { + if (!ZeroMode) Acc[r][0] = wasm_i32x4_add(Acc[r][0], load_c(c[r])); } - - wasm_v128_store64_lane(&C[0], Accumulators[0], 0); - C += 2; - - Accumulators[0] = wasm_i32x4_shuffle(Accumulators[0], wasm_i32x4_splat(0), 2, 3, 2, 3); + for (size_t r = 0; r < RowCount; ++r) store_c(c[r], Acc[r][0]); + for (size_t r = 0; r < RowCount; ++r) c[r] += cols; + }; + + if (CountN & 4) { + Tail(4, + [](int32_t* p) { return wasm_v128_load(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store(p, v); }); + for (size_t r = 0; r < RowCount; ++r) Acc[r][0] = Acc[r][1]; } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = wasm_i32x4_extract_lane(Accumulators[0], 0); - - if (!ZeroMode) { - AccumulatorValue += C[0]; + if (CountN & 2) { + Tail(2, + [](int32_t* p) { return wasm_v128_load64_zero(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store64_lane(p, v, 0); }); + for (size_t r = 0; r < RowCount; ++r) + Acc[r][0] = wasm_i32x4_shuffle(Acc[r][0], wasm_i32x4_splat(0), 2, 3, 2, 3); + } + if (CountN & 1) { + for (size_t r = 0; r < RowCount; ++r) { + int32_t v = wasm_i32x4_extract_lane(Acc[r][0], 0); + if (!ZeroMode) v += *c[r]; + *c[r] = v; } - - C[0] = AccumulatorValue; } - CountN = 0; } } + return RowCount; +} + +size_t MlasGemmQuantKernel4x8( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<4>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + +size_t MlasGemmQuantKernel1x8( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<1>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} - return 1; +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + size_t RowsHandled = 0; + if (CountM >= 4) { + RowsHandled = MlasGemmQuantKernel4x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } else { + RowsHandled = MlasGemmQuantKernel1x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } + return RowsHandled; } const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd = { diff --git a/src/lib/qladd.cpp b/src/lib/qladd.cpp index 5dafa17..4bfa074 100644 --- a/src/lib/qladd.cpp +++ b/src/lib/qladd.cpp @@ -552,6 +552,113 @@ MlasQLinearAddKernelHelper( InputA, ScaleA, ZeroPointA, InputB, 
ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } } +#elif defined(MLAS_TARGET_S390X) +template +static +void +MlasQLinearAddKernelHelper( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + if (N >= 16) { + float ScaleRatio_AC = ScaleA / ScaleC; + float ScaleRatio_BC = ScaleB / ScaleC; + MLAS_FLOAT32X4 VectorScaleRatio_AC = MlasBroadcastFloat32x4(ScaleRatio_AC); + MLAS_FLOAT32X4 VectorScaleRatio_BC = MlasBroadcastFloat32x4(ScaleRatio_BC); + MLAS_FLOAT32X4 VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); + MLAS_FLOAT32X4 vb0_lo, vb0_hi, vb1_lo, vb1_hi; + const uint8_t flip = 128; + MLAS_UNREFERENCED_PARAMETER(flip); + __vector unsigned char vmask = reinterpret_cast<__vector unsigned char>(vec_splats(flip)); + __vector signed short vmask1 = reinterpret_cast<__vector signed short>(vec_splats((short)flip)); + + if (IsScalarB) { + vb0_lo = MlasBroadcastFloat32x4((float)*InputB); + VectorFixedPart = __builtin_s390_vfmasb(vb0_lo, VectorScaleRatio_BC, VectorFixedPart); + } + while (N >= 16) { + MLAS_INT32X4 r_lo, r_hi; + MLAS_FLOAT32X4 va_lo, va_hi; + MLAS_UNREFERENCED_PARAMETER(VectorScaleRatio_AC); + MLAS_UNREFERENCED_PARAMETER(VectorScaleRatio_BC); + auto va = MlasPackL8(InputA, vmask); + auto vshort = vec_unpackh(va); + vshort = MlasPackS16(vshort, vmask1); + auto va1 = vec_unpackl(vshort); + auto va0 = vec_unpackh(vshort); + va_lo = vec_float(va0); + va_hi = vec_float(va1); + if (!IsScalarB) { + auto vb = MlasPackL8(InputB, vmask); + vshort = vec_unpackh(vb); + vshort = MlasPackS16(vshort, vmask1); + auto vb1 = vec_unpackl(vshort); + auto vb0 = vec_unpackh(vshort); + vb0_lo = vec_float(vb0); + vb0_hi = vec_float(vb1); + vshort = vec_unpackl(vb); + vshort = MlasPackS16(vshort, vmask1); + vb1 = vec_unpackl(vshort); + vb0 = vec_unpackh(vshort); + vb1_lo = vec_float(vb0); + vb1_hi = vec_float(vb1); + InputB += 16; + } + va_lo = va_lo * VectorScaleRatio_AC; + va_hi = va_hi * VectorScaleRatio_AC; + if (IsScalarB) { + r_lo = vec_signed(vec_round(va_lo + VectorFixedPart)); + r_hi = vec_signed(vec_round(va_hi + VectorFixedPart)); + } else { + vb0_lo = vb0_lo * VectorScaleRatio_BC; + vb0_hi = vb0_hi * VectorScaleRatio_BC; + r_lo = vec_signed(vec_round(VectorFixedPart + va_lo + vb0_lo)); + r_hi = vec_signed(vec_round(VectorFixedPart + va_hi + vb0_hi)); + } + const auto vc0 = vec_packs(r_lo, r_hi); + vshort = vec_unpackl(va); + vshort = MlasPackS16(vshort, vmask1); + va1 = vec_unpackl(vshort); + va0 = vec_unpackh(vshort); + va_lo = vec_float(va0); + va_hi = vec_float(va1); + va_lo = va_lo * VectorScaleRatio_AC; + va_hi = va_hi * VectorScaleRatio_AC; + if (IsScalarB) { + r_lo = vec_signed(vec_round(VectorFixedPart + va_lo)); + r_hi = vec_signed(vec_round(VectorFixedPart + va_hi)); + } else { + vb1_lo = vb1_lo * VectorScaleRatio_BC; + vb1_hi = vb1_hi * VectorScaleRatio_BC; + r_lo = vec_signed(vec_round(VectorFixedPart + va_lo + vb1_lo)); + r_hi = vec_signed(vec_round(VectorFixedPart + va_hi + vb1_hi)); + } + const auto vc1 = vec_packs(r_lo, r_hi); + MLAS_INT32X4 vc = MlasPackS16_128(vc0, vc1); + vec_xst(vc, 0, reinterpret_cast(OutputC)); + + // Workaround for bad GCC warning that variable is set but not used. 
+ MLAS_UNREFERENCED_PARAMETER(vc); + + N -= 16; + InputA += 16; + OutputC += 16; + } + } + if (N > 0) { + MlasQLinearAddKernelRawHelper( + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); + } +} #elif defined(MLAS_LSX_INTRINSICS) template diff --git a/src/lib/qladd.h b/src/lib/qladd.h index 9456894..369708e 100644 --- a/src/lib/qladd.h +++ b/src/lib/qladd.h @@ -453,6 +453,101 @@ MlasPackS16_128( return reinterpret_cast(vec_packsu(a, b)); } +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + __vector short a, + __vector short b + ) +{ + return reinterpret_cast(vec_packs(a, b)); +} +#elif defined(MLAS_TARGET_S390X) +typedef __vector signed char MLAS_INT8; +typedef __vector short MLAS_SHORT; +template +MLAS_FORCEINLINE +MLAS_INT8 +MlasPackL8( + const DataType* Input, + __vector unsigned char vmask + ); + +template <> +MLAS_FORCEINLINE +MLAS_INT8 +MlasPackL8( + const uint8_t* Input, + __vector unsigned char vmask + ) +{ + __vector unsigned char va = vec_xl(0,Input); + return reinterpret_cast(reinterpret_cast<__vector unsigned char>(va) - vmask); +} + +template <> +MLAS_FORCEINLINE +MLAS_INT8 +MlasPackL8( + const int8_t* Input, + __vector unsigned char vmask + ) +{ + MLAS_UNREFERENCED_PARAMETER(vmask); + return reinterpret_cast(vec_xl(0,Input)); +} + +template +MLAS_FORCEINLINE +MLAS_SHORT +MlasPackS16( + __vector short a, + __vector short b + ); + +template <> +MLAS_FORCEINLINE +MLAS_SHORT +MlasPackS16( + __vector short a, + __vector short b + ) +{ + return a + b; +} + +template <> +MLAS_FORCEINLINE +MLAS_SHORT +MlasPackS16( + __vector short a, + __vector short b + ) +{ + MLAS_UNREFERENCED_PARAMETER(b); + return a; +} + +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + __vector short a, + __vector short b + ); + +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + __vector short a, + __vector short b + ) +{ + return reinterpret_cast(vec_packsu(a, b)); +} + template <> MLAS_FORCEINLINE MLAS_INT32X4 diff --git a/src/lib/qlgavgpool.cpp b/src/lib/qlgavgpool.cpp index f0d2b48..746d722 100644 --- a/src/lib/qlgavgpool.cpp +++ b/src/lib/qlgavgpool.cpp @@ -15,7 +15,7 @@ Module Name: --*/ #include "mlasi.h" -#include +#include size_t MLASCALL diff --git a/src/lib/qlmul.cpp b/src/lib/qlmul.cpp index 4a6d57d..e518483 100644 --- a/src/lib/qlmul.cpp +++ b/src/lib/qlmul.cpp @@ -384,6 +384,98 @@ MlasQLinearMulKernel( MLAS_UNREFERENCED_PARAMETER(ScaleBVector); MLAS_UNREFERENCED_PARAMETER(ValueBVector); } +#elif defined(MLAS_ZVECTOR_INTRINSICS) + +template +static +void +MlasQLinearMulKernel( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + const float MinimumValue = (float)((int)std::numeric_limits::min() - ZeroPointC); + const float MaximumValue = (float)((int)std::numeric_limits::max() - ZeroPointC); + + auto ZeroPointAVector = vec_splats(int32_t(ZeroPointA)); + auto ZeroPointBVector = vec_splats(int32_t(ZeroPointB)); + auto ZeroPointCVector = vec_splats(float(ZeroPointC)); + + auto ScaleAVector = vec_splats(ScaleA); + auto ScaleBVector = vec_splats(ScaleB); + auto ScaleCVector = vec_splats(ScaleC); + + auto MinimumVector = vec_splats(MinimumValue); + auto MaximumVector = vec_splats(MaximumValue); + + float ValueB; + __vector float ValueBVector; + + if (IsScalarB) { + ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB); + ValueBVector = vec_splats(ValueB); + } + + while (N >= 
4) { + __vector int32_t IntegerAVector {InputA[0], InputA[1], InputA[2], InputA[3]}; + auto IntegerVector = IntegerAVector - ZeroPointAVector; + auto ValueAVector = ScaleAVector * vec_float(IntegerVector); + + if (!IsScalarB) { + __vector int32_t IntegerBVector {InputB[0], InputB[1], InputB[2], InputB[3]}; + IntegerVector = IntegerBVector - ZeroPointBVector; + ValueBVector = ScaleBVector * vec_float(IntegerVector); + } + + auto ValueCVector = ValueAVector * ValueBVector / ScaleCVector; + ValueCVector = vec_min(vec_max(ValueCVector, MinimumVector), MaximumVector); + ValueCVector = vec_round(ValueCVector + ZeroPointCVector); + + auto IntegerValueCVector = vec_signed(ValueCVector); + OutputC[0] = (DataType) IntegerValueCVector[0]; + OutputC[1] = (DataType) IntegerValueCVector[1]; + OutputC[2] = (DataType) IntegerValueCVector[2]; + OutputC[3] = (DataType) IntegerValueCVector[3]; + + OutputC += 4; + InputA += 4; + InputB += 4; + + N -= 4; + + // Suppress wrong GCC warnings + MLAS_UNREFERENCED_PARAMETER(ValueAVector); + } + + while (N > 0) { + float ValueA = ScaleA * (int32_t(*InputA) - ZeroPointA); + if (!IsScalarB) { + ValueB = ScaleB * (int32_t(*InputB) - ZeroPointB); + } + float ValueC = (ValueA * ValueB) / ScaleC; + ValueC = std::min(std::max(ValueC, MinimumValue), MaximumValue); + + *OutputC = (DataType)(int32_t)std::nearbyintf(ValueC + ZeroPointC); + + InputA++; + InputB++; + OutputC++; + N--; + } + + // Suppress wrong GCC warnings + MLAS_UNREFERENCED_PARAMETER(ScaleAVector); + MLAS_UNREFERENCED_PARAMETER(ScaleBVector); + MLAS_UNREFERENCED_PARAMETER(ValueBVector); +} #elif defined(MLAS_LSX_INTRINSICS) diff --git a/src/lib/qnbitgemm.cpp b/src/lib/qnbitgemm.cpp new file mode 100644 index 0000000..f34128d --- /dev/null +++ b/src/lib/qnbitgemm.cpp @@ -0,0 +1,1198 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qnbitgemm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication hardware agnostic entrypoint, MlasQNBitGemmBatch, + as well as some SQNBitGemm-related query functions. +--*/ + +#include "qnbitgemm.h" +#include "sqnbitgemm_q8_block.h" + +#include + +namespace +{ + +enum QNBitGemmVariant { + SQNBitGemmVariantInvalid = -1, + + // Valid variants + + SQ4BitGemmVariant_CompFp32 = 0, + SQ4BitGemmVariant_CompInt8, + HQ4BitGemmVariant_CompFp16, + HQ4BitGemmVariant_CompInt8, + SQ8BitGemmVariant_CompInt8, + + // End of valid variants + + // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. + // Its value is used as an array size. 
+ SQNBitGemmVariantCount, +}; + +QNBitGemmVariant +GetQNBitGemmVariant( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + if ((BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { + if (BlkBitWidth == 4) { + if (ComputeType == SQNBIT_CompFp32) { + return SQ4BitGemmVariant_CompFp32; + } else if (ComputeType == HQNBIT_CompFp16) { + return HQ4BitGemmVariant_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { + return SQ4BitGemmVariant_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQ4BitGemmVariant_CompInt8; + } + } else if (BlkBitWidth == 8) { + if (ComputeType == SQNBIT_CompInt8) { + return SQ8BitGemmVariant_CompInt8; + } + } + } + + return SQNBitGemmVariantInvalid; +} + +} // namespace + +bool MLASCALL +MlasIsQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; + if (Dispatch == nullptr) { + return false; + } + + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + + switch (Variant) { + case SQ4BitGemmVariant_CompFp32: { + return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && + Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr; + } + case HQ4BitGemmVariant_CompFp16: { + return Dispatch->HQ4BitGemmPackQuantBData != nullptr && + Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && + Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; + } + case SQ4BitGemmVariant_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 + return + (Dispatch->SQ4BitGemmKernel_Packed_CompInt8 != nullptr && Dispatch->QuantizeA_Packed_CompInt8 != nullptr) || + (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || + (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); + } + case SQ8BitGemmVariant_CompInt8: { + return Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr && + Dispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr && + Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr; + } + default: { + return false; + } + } +} + +namespace +{ + +size_t +QNBitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; + if (Dispatch == nullptr || Dispatch->QNBitGemmPerGemmWorkspaceSize == nullptr) { + return 0; + } + + if (BlkBitWidth == 4 || BlkBitWidth == 8) { + return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType, BlkBitWidth); + } + + return 0; +} + +size_t +QNBitGemmPerGemmWorkspaceAlignment( + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; + if (Dispatch == nullptr || Dispatch->QNBitGemmPerGemmWorkspaceAlignment == nullptr) { + return 1; + } + + if (BlkBitWidth == 4 || BlkBitWidth == 8) { + return Dispatch->QNBitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + } + + return 1; +} + +size_t +QNBitGemmPerGemmWorkspaceStride( + size_t M, + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto Size = QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, HasZeroPoint, ComputeType); + const auto Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, 
ComputeType); + return MlasDivRoundup(Size, Alignment) * Alignment; +} + +} // namespace + +size_t MLASCALL +MlasQNBitGemmBatchWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const size_t PerGemmWorkspaceStride = + QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, HasZeroPoint, ComputeType); + if (PerGemmWorkspaceStride == 0) { + return 0; + } + + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + + const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; + + return WorkspaceSize + Alignment - 1; +} + +size_t MLASCALL +MlasQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; + } + + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q4BitGemmPackQuantBDataSize( + N, K, BlkLen, HasZeroPoint, ComputeType + ); + } else if (BlkBitWidth == 8 && Dispatch->Q8BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q8BitGemmPackQuantBDataSize( + N, K, BlkLen, HasZeroPoint, ComputeType + ); + } + + return 0; +} + +struct PerGemmQuantAWorkspace { + PerGemmQuantAWorkspace(void* PerGemmWorkspace, size_t M, size_t BlockCountK, size_t BlkLen) + : PerGemmWorkspace_(PerGemmWorkspace), M_(M), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + QuantData = (std::byte*)PerGemmWorkspace; + QuantScale = (float*)(QuantData + M * BlockCountK * BlkLen); + BlockSum = QuantScale + M * BlockCountK; + } + std::byte* QuantData; // NxBlockCountKxBlkLen + float* QuantScale; // NxBlockCountK + float* BlockSum; // NxBlockCountK + void* PerGemmWorkspace_; // memory for above data + size_t M_, BlockCountK_, BlkLen_; +}; + +void MLASCALL +MlasQNBitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const void* QuantBData, + void* PackedQuantBDataAndOrBlkSumWorkspace, + const void* QuantBScale, + bool HasZeroPoint, + const void* QuantBZeroPoint, + MLAS_THREADPOOL* ThreadPool +) +{ + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; + if (Dispatch == nullptr) { + return; + } + + if (BlkBitWidth == 4) { + if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen, false); + Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(QuantBScale), + HasZeroPoint, + static_cast(QuantBZeroPoint), + packed_quant_b, + ThreadPool + ); + } else if (ComputeType == HQNBIT_CompFp16 && Dispatch->HQ4BitGemmPackQuantBData != nullptr) { + Dispatch->HQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); + } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. 
+ //assert(QuantBScale == nullptr); + //assert(QuantBZeroPoint == nullptr); + Dispatch->SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); + return; + } + } else if (BlkBitWidth == 8) { + if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, + BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); + Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(QuantBScale), + HasZeroPoint, + static_cast(QuantBZeroPoint), + packed_quant_b, + ThreadPool + ); + } + } +} + +bool MLASCALL +MlasQNBitGemmScalesPacked( + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + bool HasZeroPoint +) { +#ifdef MLAS_TARGET_ARM64 + if (BlkBitWidth == 4 && ComputeType == SQNBIT_CompInt8) { + const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8; + return UsePacked && UsePacked(K, BlkLen, HasZeroPoint); + } +#else + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(ComputeType); + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); +#endif // MLAS_TARGET_ARM64 + return false; +} + +namespace +{ + +MLAS_FORCEINLINE void +AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); + acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); + MlasStoreFloat32x4(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + +void +SQ4BitGemm_CompFp32( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const float* A = DataParams->A + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* a_row = A; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* a_row = A; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32( + BlkLen, + dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks + ); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); +#endif + + if (bias) { + AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc + ); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +void +HQ4BitGemm_CompFp16( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const size_t k_blk_num = MlasDivRoundup(K, BlkLen); + const size_t qldb = k_blk_num * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ldb = k_blk_num * BlkLen; + const size_t k_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blk_num); + + const MLAS_FP16* A = DataParams->A + RangeStartM * lda; + MLAS_FP16* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * 
qldb; + const MLAS_FP16* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blk_num; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_zp_bytes; + const MLAS_FP16* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias; + + // 32N is the sweet spot of cache utilization. It is machine dependent though. + constexpr size_t StrideM = 2; + constexpr size_t StrideN = 32; + + // TODO(fajin): move allocation up to the op. + size_t bufsize = ldb * StrideN * sizeof(MLAS_FP16); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant_b, QuantBData, QuantBScale, QuantBZeroPoint, countN, K, k_blk_num + ); + + const MLAS_FP16* a = A; + MLAS_FP16* c = C; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + a, dequant_b, Bias, c, countM, countN, K, lda, ldb, ldc + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + m, RangeStartN + n, countM, countN, ldc + ); + } + + a += countM * lda; + c += countM * ldc; + } + + QuantBData += countN * qldb; + QuantBScale += countN * k_blk_num; + QuantBZeroPoint = QuantBZeroPoint ? QuantBZeroPoint + countN * k_zp_bytes : nullptr; + Bias = Bias ? Bias + countN : nullptr; + C += countN; + } +} + +void +SQ4BitGemm_CompInt8( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8; + const auto SQ4BitGemm = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_Packed_CompInt8; + if (UsePacked && SQ4BitGemm && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { + const std::byte* QuantA = static_cast(PerGemmWorkspace); + SQ4BitGemm(BlkLen, QuantA, DataParams->PackedQuantBData, + DataParams->C, RangeStartM, RangeCountM, RangeStartN, RangeCountN, K, + DataParams->ldc, DataParams->Bias); + return; + } + +#ifdef MLAS_TARGET_AMD64_IX86 + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + // quant A scale is embedded in QuantData if QuantScale is nullptr. + const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen)); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; + const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; + + assert(RangeStartN % 4 == 0); + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? 
nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; +#else + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + const size_t lda = k_blks * Q8BlkSize(BlkLen); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; +#endif + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; + + if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + const auto RowsHandled = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, + RowsHandled, CountN, ldc + ); + } + + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; + } + } +#ifdef MLAS_TARGET_AMD64_IX86 + else if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) + { + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( + BlkLen, + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + ldc, + ABlockSum, + b_blk_sum + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } +#endif + } +} + +void +SQ8BitGemm_CompInt8( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); + constexpr size_t BlkBitWidth = 8; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + // quant A scale is embedded in QuantData if QuantScale is nullptr. + const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen)); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; + const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; + + assert(RangeStartN % 16 == 0); + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + const float* BlkUnsignedQuantAZeroPointCorrection = + DataParams->BlkUnsignedQuantAZeroPointCorrection ? DataParams->BlkUnsignedQuantAZeroPointCorrection + RangeStartN * k_blks : nullptr; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? 
nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + if (GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + const float* blk_unsigned_quant_A_zp_correction = BlkUnsignedQuantAZeroPointCorrection ? + BlkUnsignedQuantAZeroPointCorrection + n * k_blks : nullptr; + GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8( + BlkLen, + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + ldc, + ABlockSum, + b_blk_sum, + blk_unsigned_quant_A_zp_correction + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + } +} + +template +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth +); + +template <> +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth +) +{ + MLAS_UNREFERENCED_PARAMETER(N); + + const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8; + const auto QuantizeA_Packed = GetMlasPlatform().QNBitGemmDispatch->QuantizeA_Packed_CompInt8; + const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + + // TODO: try parallel on BatchN * M threads because BatchN is usually 1. 
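// ----------------------------------------------------------------------------
// A small standalone sketch (illustrative only, not part of the patch) of how
// the block-sum branches below carve the per-GEMM workspace. It mirrors the
// pointer arithmetic of PerGemmQuantAWorkspace defined earlier in this file;
// M, K and BlkLen are sample values, and the actual buffer size and alignment
// still come from the platform dispatch (QNBitGemmPerGemmWorkspaceSize and
// QNBitGemmPerGemmWorkspaceAlignment).

#include <cstddef>
#include <cstdio>

int main()
{
    const size_t M = 4, K = 1024, BlkLen = 32;
    const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;      // MlasDivRoundup(K, BlkLen)

    const size_t quantDataBytes = M * BlockCountK * BlkLen;    // one quantized int8 per A value
    const size_t scaleBytes     = M * BlockCountK * sizeof(float);
    const size_t blkSumBytes    = M * BlockCountK * sizeof(float);

    std::printf("QuantData  at offset 0, %zu bytes\n", quantDataBytes);
    std::printf("QuantScale at offset %zu, %zu bytes\n", quantDataBytes, scaleBytes);
    std::printf("BlockSum   at offset %zu, %zu bytes\n", quantDataBytes + scaleBytes, blkSumBytes);
    return 0;
}
// ----------------------------------------------------------------------------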
+ if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr); + }); + } else { + // TODO(hasesh): Clean-up the following logic so that it is clean AND it works as expected on all platforms + if (BlkBitWidth == 4) { + if (QuantizeARow) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); + } else if (QuantizeARow2) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); + } + } else if (BlkBitWidth == 8) { + if (QuantizeARow2) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); + } + } + } +} + +template <> +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth +) { + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(Workspace); + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); +} + +template +using InitializeWorkspaceFn = std::function* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth +)>; + +template +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant); + +template <> +InitializeWorkspaceFn 
+GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case SQ4BitGemmVariant_CompInt8: + case SQ8BitGemmVariant_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case HQ4BitGemmVariant_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template +using QNBitGemmFn = std::function* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +)>; + +template +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant); + +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case SQ4BitGemmVariant_CompFp32: + return SQ4BitGemm_CompFp32; + case SQ4BitGemmVariant_CompInt8: + return SQ4BitGemm_CompInt8; + case SQ8BitGemmVariant_CompInt8: + return SQ8BitGemm_CompInt8; + default: + return nullptr; + } +} + +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case HQ4BitGemmVariant_CompFp16: + return HQ4BitGemm_CompFp16; + default: + return nullptr; + } +} +} // namespace + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +) +{ + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + assert(Variant != SQNBitGemmVariantInvalid); + + // + // Ensure `Workspace` has correct alignment. + // + if (Workspace != nullptr) { + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); + Workspace = reinterpret_cast( + (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) + ); + } + + const bool has_zp_input = DataParams->QuantBZeroPoint; + const size_t PerGemmWorkspaceStride = + QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, has_zp_input, ComputeType); + + if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); + InitializeWorkspaceOperation != nullptr) { + InitializeWorkspaceOperation( + M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool, BlkBitWidth + ); + } + + const auto ComputeOperation = GetQNBitGemm(Variant); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + const auto* Data = &DataParams[gemm_i]; + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + } else if (Variant == SQ8BitGemmVariant_CompInt8 && 
GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + const_cast*>(Data)->BlkUnsignedQuantAZeroPointCorrection = packed_quant_b.BlkUnsignedQuantAZeroPointCorrection; + + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + } else { + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + } + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min( + nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN + ); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + const auto* Data = &DataParams[gemm_i]; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct 
packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + const_cast*>(Data)->BlkUnsignedQuantAZeroPointCorrection = packed_quant_b.BlkUnsignedQuantAZeroPointCorrection; + + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } else { + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } + }); +} + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); diff --git a/src/lib/qnbitgemm.h b/src/lib/qnbitgemm.h new file mode 100644 index 0000000..7ec80c6 --- /dev/null +++ b/src/lib/qnbitgemm.h @@ -0,0 +1,566 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qnbitgemm.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + implementing SQNBitGemm. + + SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float + matrix and B is a n-bit quantized integer matrix. B is block quantized, + meaning values of B are divided into blocks and each block has its own + scale and optional zero point. + +--*/ + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +constexpr MLAS_FORCEINLINE size_t +MlasQNBitQuantBBlkSumAlignment() +{ + // 16 floats. this alignment is required by GemmFloatKernel + return 16 * sizeof(float); +} + +constexpr MLAS_FORCEINLINE size_t +MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) +{ + return BlkLen * BlkBitWidth / 8; +} + +MLAS_FORCEINLINE void* +MlasAlignAddress(void* addr, const size_t alignment) +{ + const uintptr_t QuantBBlkSumAddr = reinterpret_cast(addr); + addr = (void*)((QuantBBlkSumAddr + alignment - 1) & (~(alignment - 1))); + return addr; +} + +template +struct PackedQuantBDataStruct { + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen, bool QuantAUnsigned) + : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); +#if defined(MLAS_TARGET_AMD64_IX86) + // avx512 requires alignment on a 64-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64); +#elif defined (MLAS_TARGET_ARM64) + // Only for 8-bit Gemms is the `PackedQuantBData` is to be 32-byte aligned and + // there is enough memory allocated to support this alignment. 
+ // See QNBitGemmPackQuantBDataSize(). + // When bit width is 4, there is no alignment guarantee. + // TODO(hasesh): Can we unify the alignment for 4-bit and 8-bit ARM64 Gemms so as to + // simpify this logic and make code here cleaner ? + if constexpr (BlkBitWidth == 8) { + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + } + else { + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; + } +#else + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; +#endif + + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + + if (QuantAUnsigned) { + BlkUnsignedQuantAZeroPointCorrection = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + BlkUnsignedQuantAZeroPointCorrection = (T*)MlasAlignAddress(BlkUnsignedQuantAZeroPointCorrection, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)BlkUnsignedQuantAZeroPointCorrection + BlkSumSize); + } else { + BlkUnsignedQuantAZeroPointCorrection = nullptr; + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + } + } + + std::byte* PackedQuantBData; + T* PackedQuantBScale; + T* QuantBBlkSum; + T* BlkUnsignedQuantAZeroPointCorrection; + + void* QuantBWorkspace_; + size_t N_, BlockCountK_, BlkLen_; +}; + +template +constexpr MLAS_FORCEINLINE size_t +MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) +{ + if constexpr (BlkBitWidth <= 4) { + return MlasDivRoundup(BlkCount, 2); // 2 blocks per byte + } else { + return BlkCount; + } +} + +// +// Kernel dispatch structure. +// + +struct MLAS_QNBIT_GEMM_DISPATCH { + // + // Quantized B data packing function prototypes. + // + + /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; + + /** Gets size of packed quantized B data containing 8-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q8BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + Q8BitGemmPackQuantBDataSize_Fn* Q8BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). 
*/ + typedef void(Q4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool + ); + + Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; + + typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; + + typedef void(SQ8BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool + ); + + SQ8BitGemmPackQuantBDataAndSumBlk_Fn* SQ8BitGemmPackQuantBDataAndBlkSum = nullptr; + + // + // Workspace size calculation function prototypes. + // + + /** + * @brief Gets the required size in bytes of the per-GEMM intermediate workspace. + * Returns a size of zero if no intermediate workspace is needed. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + typedef size_t(QNBitGemmPerGemmWorkspaceSize_Fn)( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t BlkBitWidth + ); + + QNBitGemmPerGemmWorkspaceSize_Fn* QNBitGemmPerGemmWorkspaceSize = nullptr; + + /** + * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. + * + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + typedef size_t(QNBitGemmPerGemmWorkspaceAlignment_Fn)( + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + QNBitGemmPerGemmWorkspaceAlignment_Fn* QNBitGemmPerGemmWorkspaceAlignment = nullptr; + + // + // SQNBIT_CompFp32 kernel function prototypes. + // + + /** + * @brief Multiply float matrix A with quantized 4-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. 
+ */ + typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr; + + // + // SQNBIT_CompInt8 kernel function prototypes. + // + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * A should be packed using QuantizeA_Packed_CompInt8. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. 
+ Binary data containing block quantized int8 data and scale values. + * @param PackedQuantBData Supplies the packed quantized B matrix data. + * @param[out] C Supplies the output C matrix. + * @param RangeStartM Start of M range. + * @param RangeCountM Number of rows of A and C. + * @param RangeStartN Start of N range. + * @param RangeCountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param ldc Number of elements between adjacent rows of C. + */ + typedef void(SQ4BitGemmKernel_Packed_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* PackedQuantBData, + float* C, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN, + size_t CountK, + size_t ldc, + const float* Bias + ); + + SQ4BitGemmKernel_Packed_CompInt8_Fn* SQ4BitGemmKernel_Packed_CompInt8 = nullptr; + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + */ + typedef size_t(SQ4BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum + ); + + SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 8-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. 
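+     *                        Used together with ABlockSum in a rank-BlockCountK update that folds the
+     *                        quantization zero-point terms back into C (see SQ8BitGemmKernel_BlkSum_CompInt8
+     *                        in qnbitgemm_kernel_neon.cpp).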
+ * @param BlkUnsignedQuantAZeroPointCorrection Supplies the optional input to de-bias the Gemm output to account for the +128 bias + addition when the activation input A is quantized to uint8. + */ + typedef size_t(SQ8BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* BlkUnsignedQuantAZeroPointCorrection + ); + + SQ8BitGemmKernel_BlkSum_CompInt8_Fn* SQ8BitGemmKernel_BlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountM Number of rows of A and C to process, an upper bound. + * @param CountN Number of columns of B and C to process. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks in one row of A and one column of B. + * @param ldc Number of elements between adjacent rows of C. + * @param Bias Bias vector of length N. + * + * @return The number of rows of A and C that were processed, at most CountM. + */ + typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias + ); + + SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; + + /** + * @brief Whether to use SQ4BitGemmKernel_Packed_CompInt8 for this problem. + */ + typedef bool(UsePacked_CompInt8_Fn)( + size_t K, + size_t BlkLen, + bool HasZp + ); + + UsePacked_CompInt8_Fn* UsePacked_CompInt8 = nullptr; + + /** + * @brief Block quantize values from matrix A from floats to quantized 8-bit integers. + * Used in conjunction with SQ4BitGemmKernel_Packed_CompInt8. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountM Number of rows of A. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. + */ + typedef void(QuantizeA_Packed_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountM, + size_t CountK, + std::byte* QuantA + ); + + QuantizeA_Packed_CompInt8_Fn* QuantizeA_Packed_CompInt8 = nullptr; + + /** + * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. 
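+     *                        (Typically each quantized block holds a float scale followed by BlkLen int8
+     *                        values; see sqnbitgemm_q8_block.h for the exact layout.)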
+     */
+    typedef void(QuantizeARow_CompInt8_Fn)(
+        size_t BlkLen,
+        const float* A,
+        size_t CountK,
+        std::byte* QuantA
+    );
+
+    QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr;
+
+    typedef void(QuantizeARowComputeBlkSum_CompInt8_Fn)(
+        size_t BlkLen,
+        const float* A,
+        size_t CountK,
+        std::byte* QuantA,
+        float* QuantAScale,
+        float* AScaledGroupSum  // scale_k * Sum_blklen(a_i)
+    );
+    QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr;
+
+    /**
+     * @brief Multiply fp16 matrix A rows with fp16 matrix B columns.
+     *        Results are written to fp16 matrix C.
+     *        If a bias is provided, it is added to the result.
+     *
+     * @param       A       first row of the A matrix segment. Row major.
+     * @param       B       first column of the B matrix segment. Column major.
+     * @param       Bias    the bias at the target column. Optional.
+     * @param[out]  C       first element of the output matrix segment. Row major.
+     * @param       CountM  the number of rows of A chunk.
+     * @param       CountN  the number of columns of B chunk.
+     * @param       K       the number of columns of A matrix and rows of B matrix.
+     * @param       lda     the leading dimension of A.
+     * @param       ldb     the leading dimension of B.
+     * @param       ldc     the leading dimension of C.
+     */
+    typedef void(HQ4BitGemmKernel_CompFp16_Fn)(
+        const MLAS_FP16* A,
+        const MLAS_FP16* B,
+        const MLAS_FP16* Bias,
+        MLAS_FP16* C,
+        size_t CountM,
+        size_t CountN,
+        size_t K,
+        size_t lda,
+        size_t ldb,
+        size_t ldc
+    );
+
+    HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr;
+};
diff --git a/src/lib/qnbitgemm_kernel_neon.cpp b/src/lib/qnbitgemm_kernel_neon.cpp
new file mode 100644
index 0000000..ba2b68e
--- /dev/null
+++ b/src/lib/qnbitgemm_kernel_neon.cpp
@@ -0,0 +1,602 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    qnbitgemm_kernel_neon.cpp
+
+Abstract:
+
+    This module implements the float/quantized n-bit integer matrix
+    multiplication kernels for ARM NEON.
+
+--*/
+
+#include "qnbitgemm_kernel_neon.h"
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <numeric>
+#include <vector>
+
+#include "qnbitgemm.h"
+#include "sqnbitgemm_q8_block.h"
+
+#ifdef USE_KLEIDIAI
+#include "kai/kai_common.h"
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
+#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
+#include "kai_ukernel_interface.h"
+#endif
+
+namespace sqnbitgemm_neon
+{
+
+namespace
+{
+
+//
+// Quantized B data packing function implementation.
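+// The packed layouts below are arranged so that the CompInt8 kernels can read whole sub-blocks for a
+// group of columns with contiguous loads; each packing routine documents its exact arrangement.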
+// + +template +size_t +QNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + if constexpr (BlkBitWidth == 4) { +#ifndef USE_KLEIDIAI + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType +#endif + +#ifdef USE_KLEIDIAI + if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); + } else +#endif + { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } + } else { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + if (ComputeType == SQNBIT_CompInt8) { + const size_t ScaleSize = N * BlockCountK * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // align on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; + + if constexpr (QuantAUnsigned) { + // 2 block sum + return PackedQuantBDataSize + ScaleSize + BlkSumSize + BlkSumSize; + } else { + return PackedQuantBDataSize + ScaleSize + BlkSumSize; + } + } else { + return PackedQuantBDataSize; + } + } +} + +void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + const size_t SubBlkLen = (ComputeType == SQNBIT_CompInt8) + ? ((BlkLen == 16) ? 16 : 32) + : 16; + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 |
+    //
+
+    MlasTrySimpleParallel(
+        ThreadPool, Iterations,
+        [&](ptrdiff_t tid) {
+            const size_t n = tid / BlockCountK;
+            const size_t k_blk = tid % BlockCountK;
+
+            const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
+            const std::byte* QuantBData = QuantBDataBegin + data_offset;
+            std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
+
+            for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) {
+                for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) {
+                    const std::byte src0 = QuantBData[byte_pair_idx];
+                    const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2];
+
+                    std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx];
+                    std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1];
+
+                    dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4);
+                    dst1 = (src0 >> 4) | ((src1 >> 4) << 4);
+                }
+
+                QuantBData += SubBlkDataSize;
+                PackedQuantBData += SubBlkDataSize;
+            }
+        }
+    );
+}
+
+void
+SQ4BitGemmPackQuantBDataAndBlkSum(
+    size_t N,
+    size_t K,
+    size_t BlkLen,
+    MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
+    const std::byte* QuantBDataBegin,
+    const float* QuantBScaleBegin,
+    bool HasZeroPoint,
+    const std::byte*,
+    PackedQuantBDataStruct<float, 4>& PackedQuantB,
+    MLAS_THREADPOOL* ThreadPool
+)
+{
+#ifndef USE_KLEIDIAI
+    MLAS_UNREFERENCED_PARAMETER(QuantBScaleBegin);
+    MLAS_UNREFERENCED_PARAMETER(HasZeroPoint);
+#endif
+    assert(BlkLen >= 16 && BlkLen % 16 == 0);
+
+#ifdef USE_KLEIDIAI
+    if (UseKleidiAI(K, BlkLen, HasZeroPoint)) {
+        const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel();
+        std::byte* PackedQuantBDataBegin = PackedQuantB.PackedQuantBData;
+
+        const size_t nr = ukernel.get_nr();
+        const size_t kr = ukernel.get_kr();
+        const size_t sr = ukernel.get_sr();
+
+        kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
+        params.lhs_zero_point = 1;
+        params.rhs_zero_point = 8;
+        params.scale_dt = kai_dt_bf16;
+
+        const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
+        const size_t scales_len = N * BlockCountK;
+        std::vector<uint16_t> scales(scales_len);
+        for (size_t i = 0; i < scales_len; i++) {
+            const uint32_t* i32 = reinterpret_cast<const uint32_t*>(&QuantBScaleBegin[i]);
+            scales[i] = *i32 >> 16;
+        }
+
+        kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen,
+                                                  reinterpret_cast<const uint8_t*>(QuantBDataBegin), BlockCountK * BlkLen / 2,
+                                                  nullptr, scales.data(), BlockCountK * sizeof(uint16_t),
+                                                  PackedQuantBDataBegin, 0, &params);
+    } else
+#endif
+    {
+        std::byte* PackedQuantBDataBegin = reinterpret_cast<std::byte*>(PackedQuantB.QuantBWorkspace_);
+        SQ4BitGemmPackQuantBData(N, K, BlkLen, ComputeType, QuantBDataBegin, PackedQuantBDataBegin, ThreadPool);
+    }
+}
+
+void
+Q8PackQuantB(
+    const std::byte* QuantBDataBegin,
+    std::byte* PackedQuantBDataBegin,
+    float* BlkUnsignedQuantAZeroPointCorrectionBegin,
+    MLAS_THREADPOOL* ThreadPool,
+    const size_t N,
+    const size_t K,
+    const size_t BlkLen)
+{
+    constexpr size_t SubBlkLen = 4;
+    const size_t BlkCountK = MlasDivRoundup(K, BlkLen);
+    const size_t SubBlkPerBlk = BlkLen / SubBlkLen;
+    const size_t StrideN = BlkCountK * BlkLen;
+    const size_t Iterations = N * BlkCountK;
+
+    // 4 rows x 8 columns pack together, then 4 rows x 4 columns, then per column.
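+    // i.e. for a full group of 8 columns, the 4-value sub-block of each column is stored back to back
+    // (8 * SubBlkLen bytes per sub-block row), then the next sub-block of every column, and so on.
+    // Full groups of 4 columns follow the same pattern with half the stride, and any remaining
+    // columns are left in their original per-column layout.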
+ MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / BlkCountK; + const size_t c8 = c & (~7), c8_res = c & 7; + const size_t c4 = c & (~3), c4_res = c & 3; + const size_t r_blk = tid % BlkCountK; + size_t r_subblk = r_blk * SubBlkPerBlk; + + const std::byte* src = QuantBDataBegin + c * StrideN + r_blk * BlkLen; + const uint8_t* src8 = reinterpret_cast(src); + + for (size_t i = 0; i < SubBlkPerBlk; ++i, src += SubBlkLen, ++r_subblk) { + if (c8 + 8 <= N) { // full 8 cols + std::byte* dest = + PackedQuantBDataBegin + c8 * StrideN + r_subblk * SubBlkLen * 8 + c8_res * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } else if (c4 + 4 <= N) { // full 4 cols + std::byte* dest = + PackedQuantBDataBegin + c4 * StrideN + r_subblk * SubBlkLen * 4 + c4_res * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } else { // remainder cols + std::byte* dest = + PackedQuantBDataBegin + c * StrideN + r_subblk * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } + } + + if (BlkUnsignedQuantAZeroPointCorrectionBegin) { + const int accu = std::accumulate(src8, src8 + std::min(BlkLen, K - r_blk * BlkLen), 0); + + // for sgemmc + const size_t dst_offset = ((c / 16) * BlkCountK + r_blk) * 16 + c % 16; + BlkUnsignedQuantAZeroPointCorrectionBegin[dst_offset] = static_cast(accu); + } + } + ); +} + +void +Q8ComputePackBlkSum( + const size_t BlkLen, + const size_t N, + const size_t K, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + float* BlockSum2Begin, + MLAS_THREADPOOL* ThreadPool) +{ + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n8 = n & (~7), n8_res = n & 7; + const size_t n4 = n & (~3), n4_res = n & 3; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset; + zp = (uint8_t)(*QuantBZP); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + if (BlockSum2Begin) { + BlockSum2Begin[dst_offset] = QuantBScale * (static_cast(zp) * std::min(BlkLen, K - k_blk * BlkLen) - BlockSum2Begin[dst_offset]); + } + + // re-arrange scale to the same order as packed data + if (n4 + 4 > N) { // remainder cols + *(QuantBScaleBegin + n * BlockCountK + k_blk) = QuantBScale; + } else if (n8 + 8 > N) { // full 4 cols + *(QuantBScaleBegin + n4 * BlockCountK + k_blk * 4 + n4_res) = QuantBScale; + } else { // full 8 cols + *(QuantBScaleBegin + n8 * BlockCountK + k_blk * 8 + n8_res) = QuantBScale; + } + }); +} + +/** + * 4 rows x 8 cols pack together, along all K. Then 4 rows x 4 cols, along all K. + * When rows < 4, keep original layout. + * + * dotprod: vdotq_laneq_u32. + * convert quant a from int8 to uint8. zp is 128. + * + * i8mm: vusdotq_laneq_s32. 
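+ *
+ * Note on the dotprod path: shifting every int8 activation a_k to (a_k + 128) gives
+ *   sum_k (a_k + 128) * b_k = sum_k a_k * b_k + 128 * sum_k b_k,
+ * so the extra 128 * sum_k b_k contribution must be compensated in C afterwards.
+ * BlkUnsignedQuantAZeroPointCorrection carries that per-block correction term and
+ * SQ8BitGemmKernel_BlkSum_CompInt8 applies it, scaled by the per-block A scale.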
+ */ +void +SQ8BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType */, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // Pack the quantized weights + if (QuantBDataBegin) { + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, PackedQuantB.BlkUnsignedQuantAZeroPointCorrection, ThreadPool, N, K, BlkLen); + } else { + // We ignore the scales and zero points if they are provided when pre-packing the weights as there is + // some "state" associated with 'BlkUnsignedQuantAZeroPointCorrection'. + + // We accumulate the block sum into 'BlkUnsignedQuantAZeroPointCorrection' while packing the weights + // in the previous step. If we were to use 'scales' while pre-packing the weights and if there were no + // zero points, then we would enter 'Q8ComputePackBlkSum' twice - once while pre-packing the weights + // and once while pre-packing the scales which would lead to erroneous 'BlkUnsignedQuantAZeroPointCorrection' + // computation as the buffer is "used" in-place for the "block sum" temporary values (obtained while pre-packing + // the weights) and the actual 'BlkUnsignedQuantAZeroPointCorrection' which will use the scales. + // Hence, to ensure that the piece of logic to calculate 'BlkUnsignedQuantAZeroPointCorrection' is only invoked + // once, we do it while we are pre-packing the scales and ignore any provided 'scales' and 'zero points' while + // pre-packing the weights. + // The flip side is that the user has to ensure that this function is called once each for 'weights', + // 'scales', and 'zero points'. This is a reasonable expectation and hence we go with that design. + + // Pack the block scales + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + // Pack the blksum (and BlkUnsignedQuantAZeroPointCorrection if applicable) + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, N, K, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, PackedQuantB.BlkUnsignedQuantAZeroPointCorrection, ThreadPool); + } + } +} + +// +// Workspace size calculation function implementation. +// + +size_t +QNBitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t BlkBitWidth +) +{ + MLAS_UNREFERENCED_PARAMETER(N); +#ifndef USE_KLEIDIAI + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); +#endif + + switch (ComputeType) { + case SQNBIT_CompInt8: { + // workspace buffer is used for block quantization of A to int8 +#ifdef USE_KLEIDIAI + if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = + M == 1? 
GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); + + const size_t mr = ukernel.get_mr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + } else +#endif + { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + // QuantData + Scale + BlkSum + const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); + return PerGemmWorkspaceSize; + } + } + default: { + return 0; + } + } +} + +size_t +QNBitGemmPerGemmWorkspaceAlignment( + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(BlkLen); + + switch (ComputeType) { + case SQNBIT_CompInt8: { + return Q8BlkAlignment(); + } + default: { + return 1; + } + } +} + +} // namespace + +bool +UseKleidiAI(size_t K, size_t BlkLen, bool HasZp) +{ +#ifdef USE_KLEIDIAI + bool has_dotprod = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); + return (BlkLen % 32) == 0 && (K % BlkLen) == 0 && !HasZp && has_dotprod; +#else + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(HasZp); + return false; +#endif +} + +template +size_t +SQ8BitGemmKernel_BlkSum_CompInt8( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* BlkUnsignedQuantAZeroPointCorrection +) +{ + MlasQ8Int8GemmKernelNeon( + BlkLen, + reinterpret_cast*>(QuantA), + QuantAScale, + reinterpret_cast(QuantBData), + QuantBScale, + C, + CountM, + CountN, + CountK, + Bias, + ldc + ); + + { + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd(a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + + if constexpr (QuantAUnsigned) { + { + assert(BlkUnsignedQuantAZeroPointCorrection != nullptr); + float* c_blk = C; + const float* b_blk_sum2 = BlkUnsignedQuantAZeroPointCorrection; + + size_t RowsRemaining = CountM; + const float* a_scale_row = QuantAScale; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd(a_scale_row, b_blk_sum2, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 128.f); + + c_blk += ldc * RowsHandled; + a_scale_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + } + + return CountM; +} + +} // namespace sqnbitgemm_neon + +// +// Kernel dispatch structure accessor. +// + +const MLAS_QNBIT_GEMM_DISPATCH& +GetMlasQNBitGemmDispatchNeon( + bool InitializeWithDotSupport, + bool InitializeWithI8MMSupport +) +{ + // Note: The InitializeWithX parameters are only used in the invocation of this method that initializes the static + // MLAS_QNBIT_GEMM_DISPATCH instance. 
+ + static const MLAS_QNBIT_GEMM_DISPATCH MlasQNBitGemmDispatchNeon = [&]() { + MLAS_QNBIT_GEMM_DISPATCH d; + + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<4, false>; + d.Q8BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<8, true>; + d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ8BitGemmPackQuantBDataAndBlkSum; + + d.QNBitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32; + + if (InitializeWithDotSupport) { + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; + d.UsePacked_CompInt8 = sqnbitgemm_neon::UsePacked_CompInt8; + + d.QuantizeARowComputeBlkSum_CompInt8 = sqnbitgemm_neon::QuantizeARowComputeBlkSum_CompInt8; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = sqnbitgemm_neon::SQ8BitGemmKernel_BlkSum_CompInt8; + +#ifdef USE_KLEIDIAI + d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8; + d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8; +#endif + } + + if (InitializeWithI8MMSupport) { + d.Q8BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<8, false>; + d.QuantizeARowComputeBlkSum_CompInt8 = sqnbitgemm_neon::QuantizeARowComputeBlkSum_CompInt8; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = sqnbitgemm_neon::SQ8BitGemmKernel_BlkSum_CompInt8; + } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; + d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; + d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 + + return d; + }(); + + return MlasQNBitGemmDispatchNeon; +} diff --git a/src/lib/sqnbitgemm_kernel_neon.h b/src/lib/qnbitgemm_kernel_neon.h similarity index 50% rename from src/lib/sqnbitgemm_kernel_neon.h rename to src/lib/qnbitgemm_kernel_neon.h index ef9345d..c8be42b 100644 --- a/src/lib/sqnbitgemm_kernel_neon.h +++ b/src/lib/qnbitgemm_kernel_neon.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.h + qnbitgemm_kernel_neon.h Abstract: @@ -23,6 +23,7 @@ Module Name: #include #include +#include "mlas_qnbit.h" #include "mlasi.h" namespace sqnbitgemm_neon @@ -30,13 +31,13 @@ namespace sqnbitgemm_neon // // Function declarations for SQNBitGemm ARM NEON kernel entry points. -// Refer to the prototypes in sqnbitgemm.h for documentation. +// Refer to the prototypes in qnbitgemm.h for documentation. // These are declared here so they can be used to initialize the -// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate +// MLAS_QNBIT_GEMM_DISPATCH structure and also be implemented in separate // files. 
// -// CompFp32 declarations +// SQNBIT_CompFp32 declarations void SQ4BitGemmM1Kernel_CompFp32( @@ -53,7 +54,7 @@ SQ4BitGemmM1Kernel_CompFp32( ); void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, @@ -64,7 +65,55 @@ Q4BitBlkDequantBForSgemm_CompFp32( size_t BlockCountK ); -// CompInt8 declarations +// HQNBIT_CompFp16 declarations +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +); + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +); + +#endif // !(defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)) + +// SQNBIT_CompInt8 declarations + +bool +UsePacked_CompInt8( + size_t K, + size_t BlkLen, + bool HasZp +); void QuantizeARow_CompInt8( @@ -74,6 +123,36 @@ QuantizeARow_CompInt8( std::byte* QuantA ); +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + +template +using QuantAType = typename std::conditional::type; + +template +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const QuantAType* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +); + size_t SQ4BitGemmKernel_CompInt8( size_t BlkLen, @@ -90,6 +169,35 @@ SQ4BitGemmKernel_CompInt8( const float* Bias ); +#ifdef USE_KLEIDIAI +void +QuantizeA_Packed_CompInt8( + size_t BlkLen, + const float* A, + size_t CountM, + size_t CountK, + std::byte* QuantA +); + +void +SQ4BitGemmKernel_Packed_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* PackedQuantBData, + float* C, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN, + size_t CountK, + size_t ldc, + const float *Bias +); +#endif + +bool +UseKleidiAI(size_t K, size_t BlkLen, bool HasZp); + // // General helpers. 
// diff --git a/src/lib/quantize.cpp b/src/lib/quantize.cpp index ae638fa..c5bf2b4 100644 --- a/src/lib/quantize.cpp +++ b/src/lib/quantize.cpp @@ -766,7 +766,7 @@ MlasQuantizeLinear( #else -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) template<> void @@ -902,7 +902,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_POWER) +#if !defined(MLAS_TARGET_POWER) && !defined(MLAS_TARGET_S390X) template void MLASCALL @@ -1681,6 +1681,204 @@ MlasRequantizeOutput( } } +#elif defined(MLAS_TARGET_S390X) + +template +void +MLASCALL +MlasRequantizeOutput( + const int32_t* Input, + size_t InputLeadingDimension, + OutputType* Output, + size_t OutputLeadingDimension, + const int32_t* Bias, + const float* Scale, + bool PerColumnScale, + OutputType ZeroPoint, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN + ) +{ + float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale; + float MinimumValue = float(std::numeric_limits::lowest() - ZeroPoint); + float MaximumValue = float(std::numeric_limits::max() - ZeroPoint); + + auto PerMatrixScaleVector = vec_splats(PerMatrixScaleValue); + auto MinimumVector = vec_splats(MinimumValue); + auto MaximumVector = vec_splats(MaximumValue); + auto ZeroPointVector = vec_splats(int32_t(ZeroPoint)); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(PerMatrixScaleVector); + + if (nullptr != Bias) { + Bias += StartN; + } + if (PerColumnScale) { + Scale += StartN; + } + + Input += StartM * InputLeadingDimension + StartN; + Output += StartM * OutputLeadingDimension + StartN; + + // + // Step through each row of the output matrix. + // + + while (CountM-- > 0) { + + const int32_t* bias = Bias; + const float* scale = PerColumnScale ? Scale : nullptr; + size_t n = CountN; + + auto* RowInput = Input; + auto* RowOutput = Output; + + // Process 16 cols at a time + + while (n >= 16) { + + auto IntegerVector0 = vec_xl(0, &RowInput[0]); + auto IntegerVector1 = vec_xl(0, &RowInput[4]); + auto IntegerVector2 = vec_xl(0, &RowInput[8]); + auto IntegerVector3 = vec_xl(0, &RowInput[12]); + RowInput += 16; + + if (bias != nullptr) { + IntegerVector0 = IntegerVector0 + vec_xl(0, &bias[0]); + IntegerVector1 = IntegerVector1 + vec_xl(0, &bias[4]); + IntegerVector2 = IntegerVector2 + vec_xl(0, &bias[8]); + IntegerVector3 = IntegerVector3 + vec_xl(0, &bias[12]); + bias += 16; + } + + auto FloatVector0 = vec_float(IntegerVector0); + auto FloatVector1 = vec_float(IntegerVector1); + auto FloatVector2 = vec_float(IntegerVector2); + auto FloatVector3 = vec_float(IntegerVector3); + + if (scale != nullptr) { + FloatVector0 = FloatVector0 * vec_xl(0, &scale[0]); + FloatVector1 = FloatVector1 * vec_xl(0, &scale[4]); + FloatVector2 = FloatVector2 * vec_xl(0, &scale[8]); + FloatVector3 = FloatVector3 * vec_xl(0, &scale[12]); + scale += 16; + } else { + FloatVector0 = FloatVector0 * PerMatrixScaleVector; + FloatVector1 = FloatVector1 * PerMatrixScaleVector; + FloatVector2 = FloatVector2 * PerMatrixScaleVector; + FloatVector3 = FloatVector3 * PerMatrixScaleVector; + } + + FloatVector0 = vec_max(FloatVector0, MinimumVector); + FloatVector1 = vec_max(FloatVector1, MinimumVector); + FloatVector2 = vec_max(FloatVector2, MinimumVector); + FloatVector3 = vec_max(FloatVector3, MinimumVector); + + FloatVector0 = vec_min(FloatVector0, MaximumVector); + FloatVector1 = vec_min(FloatVector1, MaximumVector); + FloatVector2 = vec_min(FloatVector2, MaximumVector); + FloatVector3 = vec_min(FloatVector3, MaximumVector); + + 
FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + auto IntegerOutVector0 = vec_signed(FloatVector0); + auto IntegerOutVector1 = vec_signed(FloatVector1); + auto IntegerOutVector2 = vec_signed(FloatVector2); + auto IntegerOutVector3 = vec_signed(FloatVector3); + + IntegerOutVector0 = IntegerOutVector0 + ZeroPointVector; + IntegerOutVector1 = IntegerOutVector1 + ZeroPointVector; + IntegerOutVector2 = IntegerOutVector2 + ZeroPointVector; + IntegerOutVector3 = IntegerOutVector3 + ZeroPointVector; + + auto ShortVector0 = vec_pack(IntegerOutVector0, IntegerOutVector1); + auto ShortVector1 = vec_pack(IntegerOutVector2, IntegerOutVector3); + auto CharVector = vec_pack(ShortVector0, ShortVector1); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(CharVector); + + vec_xst(CharVector, 0, (int8_t *) RowOutput); + RowOutput += 16; + n -= 16; + } + + while (n >= 4) { + int8_t OutputBuffer[16]; + + auto IntegerVector = vec_xl(0, &RowInput[0]); + RowInput += 4; + + if (bias != nullptr) { + IntegerVector = IntegerVector + vec_xl(0, &bias[0]); + bias += 4; + } + + auto FloatVector = vec_float(IntegerVector); + + if (scale != nullptr) { + FloatVector = FloatVector * vec_xl(0, scale); + scale += 4; + } else { + FloatVector = FloatVector * PerMatrixScaleVector; + } + + FloatVector = vec_max(FloatVector, MinimumVector); + FloatVector = vec_min(FloatVector, MaximumVector); + FloatVector = vec_round(FloatVector); + + auto IntegerOutVector = vec_signed(FloatVector); + IntegerOutVector = IntegerOutVector + ZeroPointVector; + + auto ShortVector = vec_pack(IntegerOutVector, vec_splats((int32_t) 0)); + auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(CharVector); + + vec_xst(CharVector, 0, &(OutputBuffer[0])); + memcpy(RowOutput, OutputBuffer, 4); + + RowOutput += 4; + n -= 4; + } + + while (n > 0) { + auto IntegerValue = RowInput[0]; + RowInput += 1; + + if (bias != nullptr) { + IntegerValue += bias[0]; + bias += 1; + } + + float FloatValue = float(IntegerValue); + float ScaleValue = PerColumnScale ? *scale++ : PerMatrixScaleValue; + + FloatValue *= ScaleValue; + FloatValue = std::max(FloatValue, MinimumValue); + FloatValue = std::min(FloatValue, MaximumValue); + + IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) - + MLAS_ROUNDING_BIAS_MAGIC_BITS; + + *RowOutput++ = OutputType(IntegerValue + ZeroPoint); + + n -= 1; + } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; + } +} + #elif defined(MLAS_LSX_INTRINSICS) template @@ -1704,8 +1902,8 @@ MlasRequantizeOutput( float min_f = float(std::numeric_limits::lowest() - ZeroPoint); float max_f = float(std::numeric_limits::max() - ZeroPoint); const __m128 PerMatrixScaleVector = PerColumnScale ? 
MlasReinterpretAsFloat32x4(__lsx_vldi(0)) : MlasReinterpretAsFloat32x4(__lsx_vldrepl_w(Scale, 0)); - const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&min_f))); - const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&max_f))); + const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4((__m128i)(v4f32){min_f,min_f,min_f,min_f}); + const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4((__m128i)(v4f32){max_f,max_f,max_f,max_f}); const __m128i ZeroPointVector = __lsx_vreplgr2vr_w(ZeroPoint); if (nullptr != Bias) { diff --git a/src/lib/rotary_embedding.cpp b/src/lib/rotary_embedding.cpp new file mode 100644 index 0000000..63e0a7f --- /dev/null +++ b/src/lib/rotary_embedding.cpp @@ -0,0 +1,108 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding.cpp + +Abstract: + + This module implements rotary embedding kernels for fp32/16. + +--*/ + +#include "rotary_embedding.h" + +template +void +MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +) { + const size_t half_rotary_emb_dim = rotary_emb_dim / 2; + size_t cache_idx = 0; + bool sign = false; + size_t j = 0; + for (size_t i = 0; i < rotary_emb_dim; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_rotary_emb_dim; + sign = i & 1; + j = sign ? i - 1 : i + 1; // i - sign + } else { + cache_idx = i % half_rotary_emb_dim; + sign = (i >= half_rotary_emb_dim); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; + } + float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); + float input_data_j = static_cast(input_data[j]); + float sin_data_cache_idx = static_cast(sin_data[cache_idx]); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output_data[i] = static_cast(output_data_i); + } +} + +template <> +void +MLASCALL +MlasRotaryEmbedOneRow( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output +) { + const auto* dispatch = GetMlasPlatform().RopeDispatch; + + if (dispatch == nullptr || dispatch->SRope == nullptr) { + MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output); + return; + } + + dispatch->SRope(input, sin_data, cos_data, dim, interleaved, output); +} + +template <> +void +MLASCALL +MlasRotaryEmbedOneRow( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + bool interleaved, + MLAS_FP16* output +) { + const auto* dispatch = GetMlasPlatform().RopeDispatch; + + if (dispatch == nullptr || dispatch->HRope == nullptr) { + MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output); + return; + } + + dispatch->HRope(input, sin_data, cos_data, dim, interleaved, output); +} + +template +void +MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const float* input_data, + const float* sin_data, + const float* cos_data, + size_t rotary_emb_dim, + bool interleaved, + float* output_data +); diff --git a/src/lib/rotary_embedding.h b/src/lib/rotary_embedding.h new file mode 100644 index 0000000..c017ece --- /dev/null +++ b/src/lib/rotary_embedding.h @@ -0,0 +1,57 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. 
+ +Module Name: + + rotary_embedding.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + implementing rotary embedding. + +--*/ + +#pragma once + +#include "mlasi.h" + +struct MLAS_ROPE_DISPATCH { + // rotary embedding kernel for fp32 + typedef void(SRope_Fn)( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output + ); + + SRope_Fn* SRope = nullptr; + + // rotary embedding kernel for fp16 + typedef void(HRope_Fn)( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + bool interleaved, + MLAS_FP16* output + ); + + HRope_Fn* HRope = nullptr; +}; + +template +void MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +); diff --git a/src/lib/rotary_embedding_kernel_avx2.cpp b/src/lib/rotary_embedding_kernel_avx2.cpp new file mode 100644 index 0000000..e5d1327 --- /dev/null +++ b/src/lib/rotary_embedding_kernel_avx2.cpp @@ -0,0 +1,311 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2.cpp + +Abstract: + + This module implements the rotary embedding kernels for AVX2 supported h/w. + +--*/ + + +#include + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_avx2.h" + +namespace rope_avx2 { + +namespace { + +typedef __m256 float32x8_t; +static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + +template +void +RopeKernel_Avx2_fp16_Impl( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + MLAS_FP16* output +); + +float32x8_t +load_fp16_and_convert_to_fp32(const MLAS_FP16* input) +{ + __m128i fp16 = _mm_lddqu_si128(reinterpret_cast(input)); + return _mm256_cvtph_ps(fp16); +} + +void +convert_to_fp16_and_store(MLAS_FP16* dst_fp16, const __m256 output) +{ + __m128i fp16_chunk = _mm256_cvtps_ph(output, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_fp16), fp16_chunk); +} + +template <> +void +RopeKernel_Avx2_fp16_Impl( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + MLAS_FP16* output +) +{ + // ?cast input -> const unsigned short* + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + for (; i + 7 < half_dim; i += 8, j += 8) { + float32x8_t real = load_fp16_and_convert_to_fp32(input + i); + float32x8_t imag = load_fp16_and_convert_to_fp32(input + j); + float32x8_t sin_val = load_fp16_and_convert_to_fp32(sin_data + i); + float32x8_t cos_val = load_fp16_and_convert_to_fp32(cos_data + i); + // Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + // Store back into non interleaved format + convert_to_fp16_and_store(output + i, real_out); + convert_to_fp16_and_store(output + j, imag_out); + } + for (; i < half_dim; i++, j++) { + float real = input[i].ToFloat(); + float imag = input[j].ToFloat(); + float sin_val = sin_data[i]; + float cos_val = cos_data[i]; + output[i] = MLAS_FP16(real * cos_val - imag * sin_val); + output[j] = MLAS_FP16(real * sin_val + imag * cos_val); + } +} + +template <> +void +RopeKernel_Avx2_fp16_Impl( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + 
const MLAS_FP16* cos_data, + size_t dim, + MLAS_FP16* output +) +{ + // ?cast input -> const unsigned short* + size_t i = 0; + for (; i + 15 < dim; i += 16) { + float32x8_t x0 = load_fp16_and_convert_to_fp32(input + i); + float32x8_t x1 = load_fp16_and_convert_to_fp32(input + i + 8); + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = load_fp16_and_convert_to_fp32(sin_data + i / 2); + float32x8_t cos_val = load_fp16_and_convert_to_fp32(cos_data + i / 2); + // Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + // Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + + // Store back into non interleaved format + convert_to_fp16_and_store(output + i, y0); + convert_to_fp16_and_store(output + i + 8, y1); + } + + for (; i < dim; i++) { + size_t cache_idx = i / 2; + bool sign = i & 1; + size_t j = sign ? i - 1 : i + 1; + + float output_data_i = input[i].ToFloat() * cos_data[cache_idx].ToFloat(); + float input_data_j = input[j].ToFloat(); + float sin_data_cache_idx = sin_data[cache_idx].ToFloat(); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output[i] = MLAS_FP16(output_data_i); + } +} + +template +void +RopeKernel_Avx2_fp32_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +); + +template <> +void +RopeKernel_Avx2_fp32_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +) { + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + for (; i + 7 < half_dim; i += 8, j += 8) { + float32x8_t real = _mm256_loadu_ps(input + i); + float32x8_t imag = _mm256_loadu_ps(input + j); + float32x8_t sin_val = _mm256_loadu_ps(sin_data + i); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into non interleaved format + _mm256_storeu_ps(output + i, real_out); + _mm256_storeu_ps(output + j, imag_out); + } + if (half_dim - i != 0) { + size_t rem = half_dim - i; + const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem)); + //Use a mask to load the remaining input values + float32x8_t real = _mm256_maskload_ps(input + i, mask); + float32x8_t imag = _mm256_maskload_ps(input + j, mask); + float32x8_t sin_val = _mm256_maskload_ps(sin_data + i, mask); + float32x8_t cos_val = _mm256_maskload_ps(cos_data + i, mask); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = 
_mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into non interleaved format + _mm256_maskstore_ps(output + i, mask, real_out); + _mm256_maskstore_ps(output + j, mask, imag_out); + } +} + +template <> +void +RopeKernel_Avx2_fp32_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +) { + size_t i = 0; + for (; i + 15 < dim; i += 16) { + float32x8_t x0 = _mm256_loadu_ps(input + i); + float32x8_t x1 = _mm256_loadu_ps(input + i + 8); + //Load imaginary and real values to separate non-interleaved vectors + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + _mm256_storeu_ps(output + i, y0); + _mm256_storeu_ps(output + i + 8, y1); + } + if (dim - i != 0) { + size_t rem = dim - i; + const __m256i mask0 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?8:rem))); + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0))); + float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask + float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the reminder of data using a second mask + //Load imaginary and real values to separate non-interleaved vectors + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + // Use masked loads for sin/cos data to avoid reading beyond buffer bounds + size_t cos_sin_rem = rem / 2; + const __m256i cos_sin_mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - cos_sin_rem)); + float32x8_t sin_val = _mm256_maskload_ps(sin_data + i / 2, cos_sin_mask); + float32x8_t cos_val = _mm256_maskload_ps(cos_data + i / 2, cos_sin_mask); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + 
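+        // Masked stores write back only the valid tail lanes, keeping the interleaved order.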
+        _mm256_maskstore_ps(output + i, mask0, y0);
+        _mm256_maskstore_ps(output + i + 8, mask1, y1);
+    }
+}
+
+}  // anonymous namespace
+
+void
+RopeKernel_Avx2_fp32(
+    const float* input,
+    const float* sin_data,
+    const float* cos_data,
+    size_t dim,
+    bool interleaved,
+    float* output
+) {
+    // real part and imaginary part must be paired
+    assert(dim % 2 == 0);
+    const auto* input_impl = reinterpret_cast<const float*>(input);
+    const auto* sin_impl = reinterpret_cast<const float*>(sin_data);
+    const auto* cos_impl = reinterpret_cast<const float*>(cos_data);
+    auto* output_impl = reinterpret_cast<float*>(output);
+
+    if (interleaved) {
+        RopeKernel_Avx2_fp32_Impl<true>(input_impl, sin_impl, cos_impl, dim, output_impl);
+    } else {
+        RopeKernel_Avx2_fp32_Impl<false>(input_impl, sin_impl, cos_impl, dim, output_impl);
+    }
+}
+
+void
+RopeKernel_Avx2_fp16(
+    const MLAS_FP16* input,
+    const MLAS_FP16* sin_data,
+    const MLAS_FP16* cos_data,
+    size_t dim,
+    bool interleaved,
+    MLAS_FP16* output
+)
+{
+    // real part and imaginary part must be paired
+    assert(dim % 2 == 0);
+
+    if (interleaved) {
+        RopeKernel_Avx2_fp16_Impl<true>(input, sin_data, cos_data, dim, output);
+    } else {
+        RopeKernel_Avx2_fp16_Impl<false>(input, sin_data, cos_data, dim, output);
+    }
+}
+}  // namespace rope_avx2
+
+//
+// Kernel dispatch structure definition.
+//
+const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2 = []() {
+    MLAS_ROPE_DISPATCH d;
+    d.SRope = rope_avx2::RopeKernel_Avx2_fp32;
+    d.HRope = rope_avx2::RopeKernel_Avx2_fp16;
+    return d;
+}();
diff --git a/src/lib/rotary_embedding_kernel_avx2.h b/src/lib/rotary_embedding_kernel_avx2.h
new file mode 100644
index 0000000..c08833e
--- /dev/null
+++ b/src/lib/rotary_embedding_kernel_avx2.h
@@ -0,0 +1,47 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    rotary_embedding_kernel_avx2.h
+
+Abstract:
+
+    This module includes function declarations and common helper functions for
+    rotary embedding on AVX2-enabled h/w.
+
+--*/
+
+#pragma once
+
+
+
+#include "mlasi.h"
+
+namespace rope_avx2 {
+
+// Rotary embedding kernel for FP32. Embed one hidden state vector.
+void
+RopeKernel_Avx2_fp32(
+    const float* input,
+    const float* sin_data,
+    const float* cos_data,
+    size_t dim,
+    bool interleaved,
+    float* output
+);
+
+void
+RopeKernel_Avx2_fp16(
+    const MLAS_FP16* input,
+    const MLAS_FP16* sin_data,
+    const MLAS_FP16* cos_data,
+    size_t dim,
+    bool interleaved,
+    MLAS_FP16* output
+);
+
+}  // namespace rope_avx2
diff --git a/src/lib/rotary_embedding_kernel_neon.cpp b/src/lib/rotary_embedding_kernel_neon.cpp
new file mode 100644
index 0000000..e59a95c
--- /dev/null
+++ b/src/lib/rotary_embedding_kernel_neon.cpp
@@ -0,0 +1,32 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    rotary_embedding_kernel_neon.cpp
+
+Abstract:
+
+    This module implements the rotary embedding kernels for ARM NEON.
+
+--*/
+
+#include "rotary_embedding.h"
+#include "rotary_embedding_kernel_neon.h"
+
+//
+// Kernel dispatch structure definition.
+// +const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() { + MLAS_ROPE_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.HRope = rope_neon::RopeKernel_Fp16; + } +#endif + return d; +}(); diff --git a/src/lib/rotary_embedding_kernel_neon.h b/src/lib/rotary_embedding_kernel_neon.h new file mode 100644 index 0000000..8153f65 --- /dev/null +++ b/src/lib/rotary_embedding_kernel_neon.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + rotary embedding on ARM cpu. + +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace rope_neon { + +// Rotary embedding kernel for fp16. Embed one hidden state vector. +void +RopeKernel_Fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +); + +} // namespace rope_neon diff --git a/src/lib/rotary_embedding_kernel_neon_fp16.cpp b/src/lib/rotary_embedding_kernel_neon_fp16.cpp new file mode 100644 index 0000000..3a93723 --- /dev/null +++ b/src/lib/rotary_embedding_kernel_neon_fp16.cpp @@ -0,0 +1,279 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 rotary embedding kernels for ARM NEON. + +--*/ + +#include +#include + +#include "fp16_common.h" +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_neon.h" + +namespace rope_neon { + +namespace { + +template +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +); + +template <> +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +) { + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + if (i + 7 < half_dim) { + float16x8_t real = MlasLoadFloat16x8(input + i); + float16x8_t imag = MlasLoadFloat16x8(input + j); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + for (; i + 15 < half_dim; i += 8, j += 8) { + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x8(output + i, real_out); + MlasStoreFloat16x8(output + j, imag_out); + real = MlasLoadFloat16x8(input + i + 8); + imag = MlasLoadFloat16x8(input + j + 8); + sin_val = MlasLoadFloat16x8(sin + i + 8); + cos_val = MlasLoadFloat16x8(cos + i + 8); + } + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x8(output + i, real_out); + MlasStoreFloat16x8(output + j, imag_out); + i += 8, j += 8; + } + for (; i + 3 < half_dim; i += 4, j += 4) { + float16x4_t real = MlasLoadFloat16x4(input + i); + float16x4_t imag = MlasLoadFloat16x4(input + j); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x4(output + i, real_out); 
+ MlasStoreFloat16x4(output + j, imag_out); + } + if (half_dim - i == 3) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + real = MlasLoadLaneFloat16x4<1>(input + i + 1, real); + real = MlasLoadLaneFloat16x4<2>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag); + imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 1, real_out); + MlasStoreLaneFloat16x4<2>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out); + MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out); + } else if (half_dim - i == 2) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + real = MlasLoadLaneFloat16x4<1>(input + i + 1, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 1, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out); + } else if (half_dim - i == 1) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + } +} + +template <> +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +) { + size_t i = 0; + if (i + 15 < dim) { + float16x8_t x0 = MlasLoadFloat16x8(input + i); + float16x8_t x1 = MlasLoadFloat16x8(input + i + 8); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + for (; i + 31 < dim; i += 16) { + 
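+            // x0/x1 hold 16 interleaved fp16 values (real, imag, real, imag, ...).
+            // vuzp1q/vuzp2q split them into even (real) and odd (imaginary) lanes,
+            // the rotation
+            //   real_out = real * cos - imag * sin
+            //   imag_out = real * sin + imag * cos
+            // is applied, and vzip1q/vzip2q re-interleave the results before the
+            // stores. The loads for the next 16 elements are issued after the
+            // stores so they can overlap with the arithmetic.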
float16x8_t real = vuzp1q_f16(x0, x1); + float16x8_t imag = vuzp2q_f16(x0, x1); + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + float16x8_t y0 = vzip1q_f16(real_out, imag_out); + float16x8_t y1 = vzip2q_f16(real_out, imag_out); + MlasStoreFloat16x8(output + i, y0); + MlasStoreFloat16x8(output + i + 8, y1); + x0 = MlasLoadFloat16x8(input + i + 16); + x1 = MlasLoadFloat16x8(input + i + 24); + sin_val = MlasLoadFloat16x8(sin + i + 16); + cos_val = MlasLoadFloat16x8(cos + i + 16); + } + float16x8_t real = vuzp1q_f16(x0, x1); + float16x8_t imag = vuzp2q_f16(x0, x1); + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + float16x8_t y0 = vzip1q_f16(real_out, imag_out); + float16x8_t y1 = vzip2q_f16(real_out, imag_out); + MlasStoreFloat16x8(output + i, y0); + MlasStoreFloat16x8(output + i + 8, y1); + i += 16; + } + for (; i + 7 < dim; i += 8) { + float16x4_t x0 = MlasLoadFloat16x4(input + i); + float16x4_t x1 = MlasLoadFloat16x4(input + i + 4); + float16x4_t real = vuzp1_f16(x0, x1); + float16x4_t imag = vuzp2_f16(x0, x1); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + float16x4_t y0 = vzip1_f16(real_out, imag_out); + float16x4_t y1 = vzip2_f16(real_out, imag_out); + MlasStoreFloat16x4(output + i, y0); + MlasStoreFloat16x4(output + i + 4, y1); + } + if (dim - i == 6) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); + real = MlasLoadLaneFloat16x4<2>(input + i + 4, real); + imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + MlasStoreLaneFloat16x4<1>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out); + MlasStoreLaneFloat16x4<2>(output + i + 4, real_out); + MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out); + } else if (dim - i == 4) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin 
+ i + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + MlasStoreLaneFloat16x4<1>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out); + } else if (dim - i == 2) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + } +} + +} // namespace + +void +RopeKernel_Fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +) { + // real part and imaginary part must be paired + assert(dim % 2 == 0); + + const auto* input_impl = reinterpret_cast(input); + const auto* sin_impl = reinterpret_cast(sin); + const auto* cos_impl = reinterpret_cast(cos); + auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output); + + if (interleaved) { + RopeKernel_Fp16_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } else { + RopeKernel_Fp16_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } +} + +} // namespace rope_neon diff --git a/src/lib/s390x/DgemmKernel.cpp b/src/lib/s390x/DgemmKernel.cpp new file mode 100644 index 0000000..f7b7f7c --- /dev/null +++ b/src/lib/s390x/DgemmKernel.cpp @@ -0,0 +1,87 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernel.cpp + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + +--*/ +#include "DgemmKernelZVECTOR.h" + +size_t +MLASCALL +MlasDgemmKernel( + const double* A, + const double* B, + double* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + double alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasDgemmCopyPackB or MlasDgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar multiplier (see DGEMM definition). 
+ + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ +{ + size_t RowsHandled; + + MLAS_FLOAT64X2 AlphaBroadcast = MlasBroadcastFloat64x2(alpha); + + if (CountM >= 4) { + RowsHandled = MlasDgemmProcessCount<4>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 2) { + RowsHandled = MlasDgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else { + RowsHandled = MlasDgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } + + return RowsHandled; +} diff --git a/src/lib/s390x/DgemmKernelZVECTOR.h b/src/lib/s390x/DgemmKernelZVECTOR.h new file mode 100644 index 0000000..48dc7de --- /dev/null +++ b/src/lib/s390x/DgemmKernelZVECTOR.h @@ -0,0 +1,122 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelZVECTOR.h + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + +--*/ + +#include "FgemmKernelZVECTOR.h" + +template +MLAS_FORCEINLINE +size_t +MlasDgemmProcessCount( + const double* A, + const double* B, + double* C, + size_t CountK, + size_t CountN, + size_t lda, + size_t ldc, + MLAS_FLOAT64X2 AlphaBroadcast, + bool ZeroMode + ) +{ + do { + + const double* a = A; + size_t k = CountK; + + MLAS_FLOAT64X2 Accumulators[RowCount][4]; + MLAS_FLOAT64X2 AElements[RowCount]; + MLAS_FLOAT64X2 ABroadcast[RowCount]; + + // + // Clear the block accumulators. + // + + MlasLoopUnroll()(Accumulators); + + // + // Compute the output block. + // + while (k >= 2) { + + MlasLoopUnroll()(AElements, a, lda); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 8); + + a += 2; + B += 8 * 2; + k -= 2; + } + if (k > 0) { + + MlasLoopUnroll()(ABroadcast, a, lda); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); + + a += 1; + B += 8; + k -= 1; + } + + if (CountN >= 8) { + + // + // Store the entire output block. + // + + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + + } else { + + // + // Store the partial output block. + // + + // + if (CountN >= 6) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 4) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 2) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } + // + // Store the remaining unaligned columns. + // + C += (CountN & ~1); + CountN &= 1; + + if (CountN > 0) { + + MlasLoopUnroll()(Accumulators, AlphaBroadcast); + + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + } + + break; + } + + C += 8; + CountN -= 8; + + } while (CountN > 0); + + return RowCount; +} + diff --git a/src/lib/s390x/FgemmKernelZVECTOR.h b/src/lib/s390x/FgemmKernelZVECTOR.h new file mode 100644 index 0000000..af77974 --- /dev/null +++ b/src/lib/s390x/FgemmKernelZVECTOR.h @@ -0,0 +1,333 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelZVECTOR.h + +Abstract: + + This module implements the kernels for the single/double precision matrix/matrix + multiply operation (DGEMM/SGEMM). 
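+
+    Depending on whether SINGLE is defined, the MLAS_* helper macros below map
+    to the 4-wide single precision vector operations or to the 2-wide double
+    precision ones, so the same templated kernel body serves both SGEMM and
+    DGEMM.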
+ +--*/ + +#include "mlasi.h" +#if defined(SINGLE) +#define MLAS_FLOATTYPE MLAS_FLOAT32X4 +#define MLAS_GEMMTYPE float +#define MLAS_LOAD_FLOAT MlasLoadFloat32x4 +#define MLAS_ZERO_FLOAT MlasZeroFloat32x4 +#define MLAS_STORE_FLOAT MlasStoreFloat32x4 +#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat32x4 +#define MLAS_MUL_FLOAT MlasMultiplyFloat32x4 +#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat32x4 +#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat32x4 +#else +#define MLAS_FLOATTYPE MLAS_FLOAT64X2 +#define MLAS_GEMMTYPE double +#define MLAS_LOAD_FLOAT MlasLoadFloat64x2 +#define MLAS_ZERO_FLOAT MlasZeroFloat64x2 +#define MLAS_STORE_FLOAT MlasStoreFloat64x2 +#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat64x2 +#define MLAS_MUL_FLOAT MlasMultiplyFloat64x2 +#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat64x2 +#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat64x2 +#endif +// +// Templates to ensure that a loop is unrolled. +// + +template +struct MlasLoopUnrollStep +{ + template + MLAS_FORCEINLINE + static + void + Step( + IterationArgs&&... Arguments + ) + { + IterationType::template Iteration(Arguments...); + MlasLoopUnrollStep::template Step(Arguments...); + } +}; + +template +struct MlasLoopUnrollStep +{ + template + MLAS_FORCEINLINE + static + void + Step( + IterationArgs&&... + ) + { + // Terminate the loop. + } +}; + +template +struct MlasLoopUnroll +{ + template + MLAS_FORCEINLINE + void + operator()( + IterationArgs&&... Arguments + ) + { + MlasLoopUnrollStep::template Step(Arguments...); + } +}; + +// +// Templates used with loop unrolling to perform an action on one row of the +// output. +// + +struct MlasFgemmZeroAccumulators +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4] + ) + { + Accumulators[Row][0] = MLAS_ZERO_FLOAT(); + Accumulators[Row][1] = MLAS_ZERO_FLOAT(); + Accumulators[Row][2] = MLAS_ZERO_FLOAT(); + Accumulators[Row][3] = MLAS_ZERO_FLOAT(); + } +}; + +struct MlasFgemmLoadAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE AElements[RowCount], + const MLAS_GEMMTYPE* A, + size_t lda + ) + { + AElements[Row] = MLAS_LOAD_FLOAT(A + Row * lda); + } +}; + +struct MlasFgemmBroadcastAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE ABroadcast[RowCount], + const MLAS_GEMMTYPE* A, + size_t lda + ) + { + ABroadcast[Row] = MLAS_BROADCAST_FLOAT(A + Row * lda); + } +}; + +template +struct MlasFgemmSplatAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE AElements[RowCount], + MLAS_FLOATTYPE ABroadcast[RowCount] + ) + { + ABroadcast[Row] = vec_splat(AElements[Row], Lane); + } +}; + +struct MlasFgemmMultiplyAddRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE ABroadcast[RowCount], + MLAS_FLOATTYPE BElements[4] + ) + { + Accumulators[Row][0] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[0], Accumulators[Row][0]); + Accumulators[Row][1] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[1], Accumulators[Row][1]); + Accumulators[Row][2] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[2], Accumulators[Row][2]); + Accumulators[Row][3] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[3], Accumulators[Row][3]); + } +}; + +template +MLAS_FORCEINLINE +void +MlasFgemmComputeBlock( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE ABroadcast[RowCount], + const MLAS_GEMMTYPE* B + ) +{ + MLAS_FLOATTYPE BElements[4]; +#if defined(SINGLE) + BElements[0] = 
MLAS_LOAD_FLOAT(B); + BElements[1] = MLAS_LOAD_FLOAT(B + 4); + BElements[2] = MLAS_LOAD_FLOAT(B + 8); + BElements[3] = MLAS_LOAD_FLOAT(B + 12); +#else + BElements[0] = MLAS_LOAD_FLOAT(B); + BElements[1] = MLAS_LOAD_FLOAT(B + 2); + BElements[2] = MLAS_LOAD_FLOAT(B + 4); + BElements[3] = MLAS_LOAD_FLOAT(B + 6); +#endif + + MlasLoopUnroll()(Accumulators, ABroadcast, BElements); +} + +struct MlasFgemmMultiplyAlphaRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_FLOATTYPE AlphaBroadcast + ) + { + Accumulators[Index] = MLAS_MUL_FLOAT(Accumulators[Index], AlphaBroadcast); + } +}; + +struct MlasFgemmMultiplyAlphaAddRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_FLOATTYPE AlphaBroadcast, + const MLAS_GEMMTYPE* C + ) + { +#if defined(SINGLE) + Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index], + AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 4)); +#else + Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index], + AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 2)); +#endif + } +}; + +struct MlasFgemmStoreRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_GEMMTYPE* C + ) + { +#if defined(SINGLE) + MLAS_STORE_FLOAT(C + Index * 4, Accumulators[Index]); +#else + MLAS_STORE_FLOAT(C + Index * 2, Accumulators[Index]); +#endif + } +}; + +template +struct MlasFgemmStoreVector +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_GEMMTYPE* C, + size_t ldc, + MLAS_FLOATTYPE AlphaBroadcast, + bool ZeroMode + ) + { + MLAS_GEMMTYPE* c = C + Row * ldc; + + if (ZeroMode) { + MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast); + } else { + MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast, c); + } + + MlasLoopUnroll()(Accumulators[Row], c); + + // + // Shift down any unaligned elements to the bottom for further processing. + // + + if (VectorCount < 4) { + Accumulators[Row][0] = Accumulators[Row][VectorCount]; + } + } +}; + +struct MlasFgemmMultiplyAlphaTrailing +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE AlphaBroadcast + ) + { + Accumulators[Row][0] = MLAS_MUL_FLOAT(Accumulators[Row][0], AlphaBroadcast); + } +}; + +template +struct MlasFgemmStoreScalar +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_GEMMTYPE* C, + size_t ldc, + bool ZeroMode + ) + { + MLAS_GEMMTYPE* c = C + Row * ldc + Lane; + MLAS_GEMMTYPE Value = MLAS_EXTRACT_FLOAT(Accumulators[Row][0]); + + if (!ZeroMode) { + Value += *c; + } + + *c = Value; + } +}; + diff --git a/src/lib/s390x/Quantize.cpp b/src/lib/s390x/Quantize.cpp new file mode 100644 index 0000000..6bb4475 --- /dev/null +++ b/src/lib/s390x/Quantize.cpp @@ -0,0 +1,300 @@ +#include +#include "mlasi.h" +#include + +// NOTE: Vector commands (e.g., vec_xst) need C-style casting to support various compiler versions. +// ONNX Runtime CI pipelines do not build with all compiler versions. + +template +void +MLASCALL +MlasQuantizeLinearKernel( + const float* Input, + OutputType* Output, + size_t N, + float Scale, + OutputType ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. 
+ + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + constexpr int32_t MinimumValue = std::numeric_limits::lowest(); + constexpr int32_t MaximumValue = std::numeric_limits::max(); + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 = FloatVector0 / ScaleVector; + FloatVector1 = FloatVector1 / ScaleVector; + FloatVector2 = FloatVector2 / ScaleVector; + FloatVector3 = FloatVector3 / ScaleVector; + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 = FloatVector0 + ZeroPointVector; + FloatVector1 = FloatVector1 + ZeroPointVector; + FloatVector2 = FloatVector2 + ZeroPointVector; + FloatVector3 = FloatVector3 + ZeroPointVector; + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 = vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + + if constexpr (std::is_same_v || std::is_same_v) { + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, (int8_t *)Output); + } else { + static_assert(std::is_same_v || std::is_same_v); + vec_xst(ShortVector0, 0, (int16_t *)Output); + vec_xst(ShortVector1, 0, (int16_t *)&Output[8]); + } + + Output += 16; + Input += 16; + N -= 16; + } + + for (size_t n = 0; n < N; n++) { + float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint); + FloatValue = std::max(FloatValue, float(MinimumValue)); + FloatValue = std::min(FloatValue, float(MaximumValue)); + Output[n] = (OutputType)(int32_t)FloatValue; + } +} + +template +void +MLASCALL +MlasQuantizeLinearInt4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer as int4 using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. Contains packed 4-bit elements. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. 
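+
+    Each element is quantized as
+
+        q = clamp(round(Input[n] / Scale) + ZeroPoint, Min, Max)
+
+    with Min/Max taken from Int4Traits, and two 4-bit results are packed into
+    each output byte. For example, with Scale = 0.5 and ZeroPoint = 0 an input
+    of 1.3 maps to round(2.6) = 3.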
+ +--*/ +{ + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + // Holds 16 quantized 8-bit values that will be packed into the output as packed 4-bit values. + UnpackedType TmpOutput[16] = {}; + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 = FloatVector0 / ScaleVector; + FloatVector1 = FloatVector1 / ScaleVector; + FloatVector2 = FloatVector2 / ScaleVector; + FloatVector3 = FloatVector3 / ScaleVector; + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 = FloatVector0 + ZeroPointVector; + FloatVector1 = FloatVector1 + ZeroPointVector; + FloatVector2 = FloatVector2 + ZeroPointVector; + FloatVector3 = FloatVector3 + ZeroPointVector; + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 = vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, (int8_t *)(&TmpOutput[0])); + + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); + MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); + MlasPackInt4Elements(Output++, TmpOutput[4], TmpOutput[5]); + MlasPackInt4Elements(Output++, TmpOutput[6], TmpOutput[7]); + MlasPackInt4Elements(Output++, TmpOutput[8], TmpOutput[9]); + MlasPackInt4Elements(Output++, TmpOutput[10], TmpOutput[11]); + MlasPackInt4Elements(Output++, TmpOutput[12], TmpOutput[13]); + MlasPackInt4Elements(Output++, TmpOutput[14], TmpOutput[15]); + + Input += 16; + N -= 16; + } + + for (size_t n = 0; n < N; n++) { + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + UnpackedType IntValue = static_cast(FloatValue); + + MlasSetInt4Element(Output, n, IntValue); + } +} + +void +MLASCALL +MlasQuantizeLinearU8Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS8Kernel( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU16Kernel( + 
const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS16Kernel( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} diff --git a/src/lib/s390x/QuantizeZVECTOR.cpp b/src/lib/s390x/QuantizeZVECTOR.cpp new file mode 100644 index 0000000..d6d86d0 --- /dev/null +++ b/src/lib/s390x/QuantizeZVECTOR.cpp @@ -0,0 +1,149 @@ +#include "mlasi.h" +#include + +template +void +MLASCALL +MlasQuantizeLinearZVECTOR( + const float* Input, + OutputType* Output, + size_t N, + float Scale, + OutputType ZeroPoint + ) +{ + // Workaround for bad GCC warning that Scale is set but not used. + MLAS_UNREFERENCED_PARAMETER(Scale); + + constexpr int32_t MinimumValue = std::numeric_limits::min(); + constexpr int32_t MaximumValue = std::numeric_limits::max(); + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 /= ScaleVector; + FloatVector1 /= ScaleVector; + FloatVector2 /= ScaleVector; + FloatVector3 /= ScaleVector; + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 += ZeroPointVector; + FloatVector1 += ZeroPointVector; + FloatVector2 += ZeroPointVector; + FloatVector3 += ZeroPointVector; + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 = vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, (int8_t *) Output); + + // Workaround for bad GCC warning that variable is set but not used. 
+ MLAS_UNREFERENCED_PARAMETER(CharVector); + + Output += 16; + Input += 16; + N -= 16; + } + + while (N >= 4) { + auto FloatVector = vec_xl(0, Input); + FloatVector /= ScaleVector; + FloatVector = vec_round(FloatVector); + FloatVector += ZeroPointVector; + + FloatVector = vec_max(FloatVector, MinimumValueVector); + FloatVector = vec_min(FloatVector, MaximumValueVector); + auto IntegerVector = vec_signed(FloatVector); + + auto ShortVector = vec_pack(IntegerVector, vec_splats((int32_t) 0)); + auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); + + OutputType tmp_output[sizeof(__vector float)/sizeof(OutputType)]; + vec_xst(CharVector, 0, (int8_t *) tmp_output); + memcpy(Output, tmp_output, N); + + // Workaround for bad GCC warning that variable is set but not used. + MLAS_UNREFERENCED_PARAMETER(CharVector); + + Output += 4; + Input += 4; + N -= 4; + } + + if (N > 0) { + float tmp_input[sizeof(__vector float) / sizeof(float)] = {}; + memcpy(tmp_input, Input, 4*N); + auto FloatVector = vec_xl(0, &(tmp_input[0])); + + FloatVector /= ScaleVector; + FloatVector = vec_round(FloatVector); + FloatVector += ZeroPointVector; + + FloatVector = vec_max(FloatVector, MinimumValueVector); + FloatVector = vec_min(FloatVector, MaximumValueVector); + auto IntegerVector = vec_signed(FloatVector); + + auto ShortVector = vec_pack(IntegerVector, vec_splats((int32_t) 0)); + auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); + + OutputType tmp_output[sizeof(__vector float)/sizeof(OutputType)]; + vec_xst(CharVector, 0, (int8_t *) tmp_output); + memcpy(Output, tmp_output, N); + + // Workaround for bad GCC warning that variable is set but not used. + MLAS_UNREFERENCED_PARAMETER(CharVector); + } +} + +void +MLASCALL +MlasQuantizeLinearU8KernelZVECTOR( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasQuantizeLinearZVECTOR(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS8KernelZVECTOR( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearZVECTOR(Input, Output, N, Scale, ZeroPoint); +} diff --git a/src/lib/s390x/SgemmKernel.cpp b/src/lib/s390x/SgemmKernel.cpp new file mode 100644 index 0000000..e7d8cde --- /dev/null +++ b/src/lib/s390x/SgemmKernel.cpp @@ -0,0 +1,87 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernel.cpp + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + +--*/ +#include "SgemmKernelZVECTOR.h" + +size_t +MLASCALL +MlasSgemmKernel( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. 
+ + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar multiplier (see SGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ +{ + size_t RowsHandled; + + MLAS_FLOAT32X4 AlphaBroadcast = MlasBroadcastFloat32x4(alpha); + + if (CountM >= 4) { + RowsHandled = MlasSgemmProcessCount<4>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 2) { + RowsHandled = MlasSgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else { + RowsHandled = MlasSgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } + + return RowsHandled; +} diff --git a/src/lib/s390x/SgemmKernelZVECTOR.cpp b/src/lib/s390x/SgemmKernelZVECTOR.cpp new file mode 100644 index 0000000..e87913a --- /dev/null +++ b/src/lib/s390x/SgemmKernelZVECTOR.cpp @@ -0,0 +1,451 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelZVECTOR.cpp + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + +--*/ + +#include "SgemmKernelZVECTOR.h" + +#include + +struct MlasSgemmBroadcastAElementsZVECTOR +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 ABroadcast[RowCount], + const float* A, + size_t lda + ) + { + ABroadcast[0][Row] = A [Row * lda]; + } +}; + +template +MLAS_FORCEINLINE +void +MlasSgemmComputeAElements( + MLAS_FLOAT32X4 AElements[RowCount], + MLAS_FLOAT32X4 ABroadcast[RowCount] + ) +{ + const __vector unsigned char mask0 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }; + const __vector unsigned char mask3 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }; + const __vector unsigned char mask_even = { 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27 }; + const __vector unsigned char mask_odd = { 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31 }; + + __vector float a1,a2; + + a1 = vec_perm(AElements[0], AElements[1], mask_even); + a2 = vec_perm(AElements[2], AElements[3], mask_even); + ABroadcast[0] = vec_perm(a1, a2, mask0); + ABroadcast[2] = vec_perm(a1, a2, mask3); + a1 = vec_perm(AElements[0], AElements[1], mask_odd); + a2 = vec_perm(AElements[2], AElements[3], mask_odd); + ABroadcast[1] = vec_perm(a1, a2, mask0); + ABroadcast[3] = vec_perm(a1, a2, mask3); +} +template +MLAS_FORCEINLINE +void +MlasSgemmComputeBlockZVECTOR( + MLAS_FLOAT32X4 acc[32], + MLAS_FLOAT32X4 ABroadcast, + MLAS_FLOAT32X4 A2Broadcast, + const float* B, + size_t CountM + ) +{ + + MLAS_FLOAT32X4 AElements[8]; + + AElements[0] = vec_splats(ABroadcast[0]); + AElements[1] = vec_splats(ABroadcast[1]); + AElements[2] = vec_splats(ABroadcast[2]); + AElements[3] = vec_splats(ABroadcast[3]); + + if (CountM == 8) { + AElements[4] = vec_splats(A2Broadcast[0]); + AElements[5] = vec_splats(A2Broadcast[1]); + AElements[6] = vec_splats(A2Broadcast[2]); + AElements[7] = vec_splats(A2Broadcast[3]); + } + + MLAS_FLOAT32X4 BElements[4]; + + BElements[0] = MlasLoadFloat32x4(B); + BElements[1] = MlasLoadFloat32x4(B + 4); + BElements[2] = MlasLoadFloat32x4(B + 8); + BElements[3] = MlasLoadFloat32x4(B + 12); + + acc[0] = 
__builtin_s390_vfmasb(AElements[0], BElements[0], acc[0]); + acc[1] = __builtin_s390_vfmasb(AElements[1], BElements[0], acc[1]); + acc[2] = __builtin_s390_vfmasb(AElements[2], BElements[0], acc[2]); + acc[3] = __builtin_s390_vfmasb(AElements[3], BElements[0], acc[3]); + + acc[4] = __builtin_s390_vfmasb(AElements[0], BElements[1], acc[4]); + acc[5] = __builtin_s390_vfmasb(AElements[1], BElements[1], acc[5]); + acc[6] = __builtin_s390_vfmasb(AElements[2], BElements[1], acc[6]); + acc[7] = __builtin_s390_vfmasb(AElements[3], BElements[1], acc[7]); + + acc[8] = __builtin_s390_vfmasb(AElements[0], BElements[2], acc[8]); + acc[9] = __builtin_s390_vfmasb(AElements[1], BElements[2], acc[9]); + acc[10] = __builtin_s390_vfmasb(AElements[2], BElements[2], acc[10]); + acc[11] = __builtin_s390_vfmasb(AElements[3], BElements[2], acc[11]); + + acc[12] = __builtin_s390_vfmasb(AElements[0], BElements[3], acc[12]); + acc[13] = __builtin_s390_vfmasb(AElements[1], BElements[3], acc[13]); + acc[14] = __builtin_s390_vfmasb(AElements[2], BElements[3], acc[14]); + acc[15] = __builtin_s390_vfmasb(AElements[3], BElements[3], acc[15]); + + if (CountM == 8) { + acc[16] = __builtin_s390_vfmasb(AElements[4], BElements[0], acc[16]); + acc[17] = __builtin_s390_vfmasb(AElements[5], BElements[0], acc[17]); + acc[18] = __builtin_s390_vfmasb(AElements[6], BElements[0], acc[18]); + acc[19] = __builtin_s390_vfmasb(AElements[7], BElements[0], acc[19]); + + acc[20] = __builtin_s390_vfmasb(AElements[4], BElements[1], acc[20]); + acc[21] = __builtin_s390_vfmasb(AElements[5], BElements[1], acc[21]); + acc[22] = __builtin_s390_vfmasb(AElements[6], BElements[1], acc[22]); + acc[23] = __builtin_s390_vfmasb(AElements[7], BElements[1], acc[23]); + + acc[24] = __builtin_s390_vfmasb(AElements[4], BElements[2], acc[24]); + acc[25] = __builtin_s390_vfmasb(AElements[5], BElements[2], acc[25]); + acc[26] = __builtin_s390_vfmasb(AElements[6], BElements[2], acc[26]); + acc[27] = __builtin_s390_vfmasb(AElements[7], BElements[2], acc[27]); + + acc[28] = __builtin_s390_vfmasb(AElements[4], BElements[3], acc[28]); + acc[29] = __builtin_s390_vfmasb(AElements[5], BElements[3], acc[29]); + acc[30] = __builtin_s390_vfmasb(AElements[6], BElements[3], acc[30]); + acc[31] = __builtin_s390_vfmasb(AElements[7], BElements[3], acc[31]); + } +} +template +struct MlasSgemmStoreVectorZVECTOR +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Result[4], + float* C, + size_t ldc, + MLAS_FLOAT32X4 AlphaBroadcast, + bool ZeroMode + ) + { + MLAS_FLOAT32X4 *rowC; + if (ZeroMode) { + rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); + rowC[0] = Result[Row] * AlphaBroadcast; + } else { + rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); + rowC[0] += Result[Row] * AlphaBroadcast; + } + } +}; + +struct MlasSgemmMultiplyAlphaTrailingZVECTOR +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[RowCount], + MLAS_FLOAT32X4 AlphaBroadcast + ) + { + Accumulators[Row] = MlasMultiplyFloat32x4(Accumulators[Row], AlphaBroadcast); + } +}; +template +struct MlasSgemmStoreScalarZVECTOR +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[RowCount], + float* C, + size_t ldc, + bool ZeroMode + ) + { + float* c = C + Row * ldc + Lane; + float Value = Accumulators[Row][Lane]; + if (!ZeroMode) { + Value += *c; + } + + *c = Value; + } +}; + +template +MLAS_FORCEINLINE +size_t +MlasSgemmZVECTORProcessCount( + const float* A, + const float* B, + float* C, + size_t CountM, + size_t 
CountK, + size_t CountN, + size_t lda, + size_t ldc, + MLAS_FLOAT32X4 AlphaBroadcast, + bool ZeroMode + ) +{ + do { + + const float* a = A; + size_t k = CountK; + + MLAS_FLOAT32X4 AElements[RowCount]; + MLAS_FLOAT32X4 ABroadcast[RowCount] = { 0 }; + MLAS_FLOAT32X4 A2Broadcast[RowCount] = { 0 }; + MLAS_FLOAT32X4 acc[32] = { 0 }; + MLAS_FLOAT32X4 Accumulators[2][RowCount] = {{0}}; + + // + // Compute the output block. + // + while (k >= 4) { + + MlasLoopUnroll()(AElements, a, lda); + MlasSgemmComputeAElements(AElements, ABroadcast); + if (CountM == 8) { + MlasLoopUnroll()(AElements, a + ( lda * 4), lda); + MlasSgemmComputeAElements(AElements, A2Broadcast); + } + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM); + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[1], A2Broadcast[1], B+16, CountM); + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[2], A2Broadcast[2], B+32, CountM); + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[3], A2Broadcast[3], B+48, CountM); + B += 16 * 4; + a += 4; + k -= 4; + } + + while (k > 0) { + MlasLoopUnroll()(ABroadcast, a, lda); + if (CountM == 8) { + MlasLoopUnroll()(A2Broadcast, a + (lda * 4), lda); + } + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM); + a += 1; + B += 16; + k -= 1; + } + if (CountN >= 16) { + + // + // Store the entire output block. + // + MlasLoopUnroll>()(acc, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 4, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 8, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 12, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(acc + 16, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 20, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 24, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 28, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + } + } else { + + // + // Store the partial output block. 
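+            // Full groups of four columns are written straight from acc[] by the
+            // vector store helpers; whatever is left after the last full group is
+            // copied into Accumulators[] so it can be scaled by alpha and stored
+            // column by column further down.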
+ // + + if (CountN >= 12) { + MlasLoopUnroll>()(acc, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 4, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 8, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(acc + 16, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 20, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 24, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + if (CountN - 12 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[1][i] = acc[i + 28]; + } + } + } + if (CountN - 12 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[0][i] = acc[i + 12]; + } + } + } else if (CountN >= 8) { + MlasLoopUnroll>()(acc, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 4, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(acc + 16, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 20, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + if (CountN - 8 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[1][i] = acc[i + 24]; + } + } + } + if (CountN - 8 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[0][i] = acc[i + 8]; + } + } + } else if (CountN >= 4) { + MlasLoopUnroll>()(acc, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(acc + 16, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + if (CountN - 4 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[1][i] = acc[i + 20]; + } + } + } + if (CountN - 4 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[0][i] = acc[i + 4]; + } + } + } else { + for (size_t i = 0; i < 4; ++i) { + Accumulators[0][i] = acc[i]; + } + + if (CountM == 8) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[1][i] = acc[i + 16]; + } + } + } + + // + // Store the remaining unaligned columns. + // + + C += (CountN & ~3); + CountN &= 3; + + if (CountN > 0) { + + MlasLoopUnroll()(Accumulators[0], AlphaBroadcast); + MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll()(Accumulators[1], AlphaBroadcast); + MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); + } + if (CountN >= 2) { + MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); + } + } + if (CountN >= 3) { + MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); + } + } + } + + break; + } + + C += 16; + CountN -= 16; + + } while (CountN > 0); + + return CountM; +} + +size_t +MLASCALL +MlasSgemmKernelZVECTOR( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. 
+ + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar multiplier (see SGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ +{ + size_t RowsHandled; + MLAS_FLOAT32X4 AlphaBroadcast = MlasBroadcastFloat32x4(alpha); + + if (CountM >= 8) { + RowsHandled = MlasSgemmZVECTORProcessCount<4>(A, B, C, 8 ,CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 4) { + RowsHandled = MlasSgemmZVECTORProcessCount<4>(A, B, C, 4, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 2) { + RowsHandled = MlasSgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else { + RowsHandled = MlasSgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } + + return RowsHandled; +} diff --git a/src/lib/s390x/SgemmKernelZVECTOR.h b/src/lib/s390x/SgemmKernelZVECTOR.h new file mode 100644 index 0000000..83c66eb --- /dev/null +++ b/src/lib/s390x/SgemmKernelZVECTOR.h @@ -0,0 +1,139 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelZVECTOR.h + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + +--*/ + +#include "FgemmKernelZVECTOR.h" + +template +MLAS_FORCEINLINE +size_t +MlasSgemmProcessCount( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountN, + size_t lda, + size_t ldc, + MLAS_FLOAT32X4 AlphaBroadcast, + bool ZeroMode + ) +{ + do { + + const float* a = A; + size_t k = CountK; + + MLAS_FLOAT32X4 Accumulators[RowCount][4]; + MLAS_FLOAT32X4 AElements[RowCount]; + MLAS_FLOAT32X4 ABroadcast[RowCount]; + + // + // Clear the block accumulators. + // + + MlasLoopUnroll()(Accumulators); + + // + // Compute the output block. + // + + while (k >= 4) { + + MlasLoopUnroll()(AElements, a, lda); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 16); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 32); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 48); + + a += 4; + B += 16 * 4; + k -= 4; + } + + while (k > 0) { + + MlasLoopUnroll()(ABroadcast, a, lda); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); + + a += 1; + B += 16; + k -= 1; + } + + if (CountN >= 16) { + + // + // Store the entire output block. + // + + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + + } else { + + // + // Store the partial output block. + // + + if (CountN >= 12) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 8) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 4) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } + + // + // Store the remaining unaligned columns. 
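+            // CountN & ~3 is the number of columns already stored as whole vectors
+            // above, so C advances past them and the low two bits of CountN select
+            // how many leftover columns (0-3) are written lane by lane after the
+            // alpha multiply, e.g. CountN == 7 advances C by 4 and leaves 3 columns.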
+ // + + C += (CountN & ~3); + CountN &= 3; + + if (CountN > 0) { + + MlasLoopUnroll()(Accumulators, AlphaBroadcast); + + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + + if (CountN >= 2) { + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + } + + if (CountN >= 3) { + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + } + } + + break; + } + + C += 16; + CountN -= 16; + + } while (CountN > 0); + + return RowCount; +} + diff --git a/src/lib/s390x/qgemm_kernel_zvector.cpp b/src/lib/s390x/qgemm_kernel_zvector.cpp new file mode 100644 index 0000000..92aa72b --- /dev/null +++ b/src/lib/s390x/qgemm_kernel_zvector.cpp @@ -0,0 +1,1409 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_zvector.cpp + +Abstract: + + This module implements QGEMM kernel for S390X. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" +#include + +struct MLAS_GEMM_QUANT_KERNEL_ZVECTOR +{ + typedef int8_t PackedAType; + typedef uint8_t PackedBType; + typedef int8_t OffsetAType; + typedef uint8_t OffsetBType; + static constexpr size_t PackedK = 4; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 16, 256, 384 }; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 16, 128, 128 }; +}; + +constexpr size_t MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_QUANT_KERNEL_ZVECTOR::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedStrides; + +using vector_int = __attribute__((vector_size(16))) int; + +template +MLAS_FORCEINLINE +static +vector_int vec_sum4s_impl(Vtype value) +{ + const __vector unsigned char mask_interleave = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + + __vector signed char signed_value = (__vector signed char) vec_perm(value, value, mask_interleave); + + auto tmp1 = vec_unpackh(vec_unpackh(signed_value)); + auto tmp2 = vec_unpackl(vec_unpackh(signed_value)); + auto tmp3 = vec_unpackh(vec_unpackl(signed_value)); + auto tmp4 = vec_unpackl(vec_unpackl(signed_value)); + + return (__vector int) (tmp1 + tmp2 + tmp3 + tmp4); +} + +#define INC_BUFFER(cnt) \ + ColumnSumBuffer += cnt; \ + if (ZeroPointB != nullptr) { \ + ZeroPointB += cnt; \ + } +template<> +MLAS_FORCEINLINE constexpr +int32_t +MlasGemmQuantFixupZeroPointA( + int32_t ZeroPointA, + bool AIsSigned + ) +{ + if (!AIsSigned) { + ZeroPointA = MLAS_GEMM_QUANT_KERNEL_ZVECTOR::OffsetAType(ZeroPointA ^ 0x80); + } + return ZeroPointA; +} + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmQuantFixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_QUANT_KERNEL_ZVECTOR::OffsetBType(ZeroPointB ^ 0x80); + } + return ZeroPointB; + +} + +template +void +MlasGemmQuantCopyPackA8x8( + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) +{ + constexpr uint8_t Flip = (AIsSigned ? 0 : 0x80); + Vtype vmask = reinterpret_cast(vec_splats(Flip)); + + const __vector unsigned char mask0 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }; + const __vector unsigned char mask3 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }; + const __vector unsigned char mask_even = { 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27 }; + const __vector unsigned char mask_odd = { 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31 }; + + // Process eight rows of matrix A in a loop. 
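+    // Alongside the packing, the per-row byte sums are accumulated in vsum/vsum2
+    // and written to RowSumBuffer; they are used for the zero-point adjustment
+    // when the int32 accumulators are post-processed.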
+ // + // The buffer is packed as a series of 4x4 byte vectors to help + // in getting into MMA loop. + // + // Unsigned buffers are converted to signed buffers in order to + // share a common kernel. + // This pattern is repeated (CountK / 4) times. + // + // If CountK is not aligned to a multiple of four, then the vector is padded + // with zeroes. + // + while (CountM >= 8) { + const uint8_t *a = A; + __vector int vsum = { 0 }; + __vector int vsum2 = { 0 }; + size_t y = CountK; + while (y >= 16) { + Vtype a1 = *reinterpret_cast(&a[0]); + Vtype a2 = *reinterpret_cast(&a[lda]); + Vtype a3 = *reinterpret_cast(&a[lda * 2]); + Vtype a4 = *reinterpret_cast(&a[lda * 3]); + Vtype vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + Vtype vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + Vtype vx2 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_odd)); + Vtype vx3 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_odd)); + Vtype vx4 = vec_perm(vx, vx1, mask0); + Vtype vx5 = vec_perm(vx2, vx3, mask0); + Vtype vx6 = vec_perm(vx, vx1, mask3); + Vtype vx7 = vec_perm(vx2, vx3, mask3); + a1 = *reinterpret_cast(&a[lda*4]); + a2 = *reinterpret_cast(&a[lda*5]); + a3 = *reinterpret_cast(&a[lda*6]); + a4 = *reinterpret_cast(&a[lda*7]); + vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + vx2 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_odd)); + vx3 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_odd)); + Vtype vx8 = vec_perm(vx, vx1, mask0); + Vtype vx9 = vec_perm(vx2, vx3, mask0); + Vtype vx10 = vec_perm(vx, vx1, mask3); + Vtype vx11 = vec_perm(vx2, vx3, mask3); + Vtype vxx = AIsSigned ? vx4 : vx4 - vmask; + vsum += vec_sum4s_impl(vxx); + *reinterpret_cast(&D[0]) = vxx; + vxx = AIsSigned ? vx5 : vx5 - vmask; + vsum += vec_sum4s_impl(vxx); + *reinterpret_cast(&D[16]) = vxx; + vxx = AIsSigned ? vx6 : vx6 - vmask; + vsum += vec_sum4s_impl(vxx); + *reinterpret_cast(&D[32]) = vxx; + vxx = AIsSigned ? vx7 : vx7 - vmask; + vsum += vec_sum4s_impl(vxx); + *reinterpret_cast(&D[48]) = vxx; + vxx = AIsSigned ? vx8 : vx8 - vmask; + *reinterpret_cast(&D[64]) = vxx; + vsum2 += vec_sum4s_impl(vxx); + vxx = AIsSigned ? vx9 : vx9 - vmask; + *reinterpret_cast(&D[80]) = vxx; + vsum2 += vec_sum4s_impl(vxx); + vxx = AIsSigned ? vx10 : vx10 - vmask; + *reinterpret_cast(&D[96]) = vxx; + vsum2 += vec_sum4s_impl(vxx); + vxx = AIsSigned ? vx11 : vx11 - vmask; + *reinterpret_cast(&D[112]) = vxx; + vsum2 += vec_sum4s_impl(vxx); + D += 16 * 8; + a += 16; + y -= 16; + } + size_t yval = y; + while (y >= 4) + { + int a1 = *reinterpret_cast(&a[0]); + int a2 = *reinterpret_cast(&a[lda]); + int a3 = *reinterpret_cast(&a[lda*2]); + int a4 = *reinterpret_cast(&a[lda*3]); + __vector int vx1 = { a1, a2, a3, a4}; + Vtype vx = AIsSigned ? 
reinterpret_cast(vx1) : reinterpret_cast(vx1) - vmask; + vsum += vec_sum4s_impl(vx); + *reinterpret_cast(&D[0]) = vx; + a1 = *reinterpret_cast(&a[lda*4]); + a2 = *reinterpret_cast(&a[lda*5]); + a3 = *reinterpret_cast(&a[lda*6]); + a4 = *reinterpret_cast(&a[lda*7]); + __vector int vx2 = { a1, a2, a3, a4}; + vx = AIsSigned ? reinterpret_cast(vx2) : reinterpret_cast(vx2) - vmask; + vsum2 += vec_sum4s_impl(vx); + if (CountK & 3) { + if (yval >= 12) { + *reinterpret_cast(&D[64]) = vx; + } else if (yval >= 8) { + *reinterpret_cast(&D[48]) = vx; + } else { + *reinterpret_cast(&D[32]) = vx; + } + } else { + if (yval >= 12) { + *reinterpret_cast(&D[48]) = vx; + } else if (yval >= 8) { + *reinterpret_cast(&D[32]) = vx; + } else { + *reinterpret_cast(&D[16]) = vx; + } + } + D += 16; + a += 4; + y -= 4; + } + if (yval >= 12) { + if (!(CountK & 3)) { + D += 48; + } + } else if (yval >= 8) { + if (!(CountK & 3)) { + D += 32; + } + } else if (yval >= 4) { + if (!(CountK & 3)) { + D += 16; + } + } + if (y >= 1) + { + Vtype a1 = vmask; + Vtype a2 = vmask; + Vtype a3 = vmask; + Vtype a4 = vmask; + a1[0] = a[0]; + a2[0] = a[lda]; + a3[0] = a[lda * 2]; + a4[0] = a[lda * 3]; + if (y >= 2) { + a1[1] = a[1]; + a2[1] = a[lda + 1]; + a3[1] = a[lda * 2 + 1]; + a4[1] = a[lda * 3 + 1]; + } + if (y >= 3) { + a1[2] = a[2]; + a2[2] = a[lda + 2]; + a3[2] = a[lda * 2 + 2]; + a4[2] = a[lda * 3 + 2]; + } + Vtype vx = reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + Vtype vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + Vtype vx2 = vec_perm(vx, vx1, mask0); + Vtype vx3 = AIsSigned ? vx2 : vx2 - vmask; + vsum += vec_sum4s_impl(vx3); + + *reinterpret_cast(&D[0]) = vx3; + a1 = vmask; + a2 = vmask; + a3 = vmask; + a4 = vmask; + a1[0] = a[lda * 4]; + a2[0] = a[lda * 5]; + a3[0] = a[lda * 6]; + a4[0] = a[lda * 7]; + if (y >= 2) { + a1[1] = a[lda * 4 + 1]; + a2[1] = a[lda * 5 + 1]; + a3[1] = a[lda * 6 + 1]; + a4[1] = a[lda * 7 + 1]; + } + if (y >= 3) { + a1[2] = a[lda * 4 + 2]; + a2[2] = a[lda * 5 + 2]; + a3[2] = a[lda * 6 + 2]; + a4[2] = a[lda * 7 + 2]; + } + vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + vx2 = vec_perm(vx, vx1, mask0); + vx3 = AIsSigned ? vx2 : vx2 - vmask; + vsum2 += vec_sum4s_impl(vx3); + if (CountK % 16 >= 12) { + *reinterpret_cast(&D[64]) = vx3; + D += 80; + } else if (CountK % 16 >= 8) { + *reinterpret_cast(&D[48]) = vx3; + D += 64; + } else if (CountK % 16 >= 4) { + *reinterpret_cast(&D[32]) = vx3; + D += 48; + } else { + *reinterpret_cast(&D[16]) = vx3; + D += 16 * 2; + } + a += 16; + } + A += lda * 8; + + vec_xst(vsum, 0, &(RowSumBuffer[0])); + vec_xst(vsum2, 16, &(RowSumBuffer[0])); + + RowSumBuffer += 8; + CountM -= 8; + } + + // Process four rows of matrix A in a loop. 
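Editor's note: row sums are accumulated while A is packed (the vsum vectors written to RowSumBuffer above), and the B packing later in this file does the same for column sums, because the zero-point correction of a quantized GEMM factors into per-row and per-column terms. The sketch below shows only the underlying identity; MLAS applies the correction through RowSumBuffer, ColumnSumBuffer and the ZeroPointB path in the store routines with its own sign and scaling conventions, so the function is illustrative, not MLAS code.

// Zero-point expansion that motivates the row/column sums gathered during packing.
#include <cstdint>
#include <cstddef>

int32_t QuantDotReference(const uint8_t* a, const uint8_t* b, size_t K,
                          int32_t zeroPointA, int32_t zeroPointB)
{
    int32_t dot = 0, rowSum = 0, colSum = 0;
    for (size_t k = 0; k < K; ++k) {
        dot += int32_t(a[k]) * int32_t(b[k]);
        rowSum += a[k];
        colSum += b[k];
    }
    // (a - za) . (b - zb) expands into four terms; only the raw dot product
    // needs the wide multiplies, the rest are cheap per-row/per-column fixups.
    return dot - zeroPointB * rowSum - zeroPointA * colSum
               + int32_t(K) * zeroPointA * zeroPointB;
}
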
+ // + if (CountM >= 4) + { + const uint8_t *a = A; + __vector int vsum = { 0 }; + size_t y = CountK; + + while (y >= 16) + { + Vtype a1 = *reinterpret_cast(&a[0]); + Vtype a2 = *reinterpret_cast(&a[lda]); + Vtype a3 = *reinterpret_cast(&a[lda * 2]); + Vtype a4 = *reinterpret_cast(&a[lda * 3]); + Vtype vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + Vtype vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + Vtype vx2 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_odd)); + Vtype vx3 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_odd)); + Vtype vx4 = vec_perm(vx, vx1, mask0); + Vtype vx5 = vec_perm(vx2, vx3, mask0); + Vtype vx6 = vec_perm(vx, vx1, mask3); + Vtype vx7 = vec_perm(vx2, vx3, mask3); + Vtype vx0 = AIsSigned ? vx4 : vx4 - vmask; + *reinterpret_cast(&D[0]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx5 : vx5 - vmask; + *reinterpret_cast(&D[16]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx6 : vx6 - vmask; + *reinterpret_cast(&D[32]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx7 : vx7 - vmask; + *reinterpret_cast(&D[48]) = vx0; + vsum += vec_sum4s_impl(vx0); + D += 16 * 4; + a += 16; + y -= 16; + } + while (y >= 4) + { + int a1 = *reinterpret_cast(&a[0]); + int a2 = *reinterpret_cast(&a[lda]); + int a3 = *reinterpret_cast(&a[lda*2]); + int a4 = *reinterpret_cast(&a[lda*3]); + __vector int vx1 = { a1, a2, a3, a4}; + Vtype vx = AIsSigned ? reinterpret_cast(vx1) : reinterpret_cast(vx1) - vmask; + *reinterpret_cast(&D[0]) = vx; + vsum += vec_sum4s_impl(vx); + D += 16; + a += 4; + y -= 4; + } + if (y >= 1) + { + Vtype vx = vmask; + vx[0] = a[0]; + vx[4] = a[lda]; + vx[8] = a[lda * 2]; + vx[12] = a[lda * 3]; + if (y >= 2) { + vx[1] = a[1]; + vx[5] = a[lda + 1]; + vx[9] = a[lda * 2 + 1]; + vx[13] = a[lda * 3 + 1]; + } + if (y >= 3) { + vx[2] = a[2]; + vx[6] = a[lda + 2]; + vx[10] = a[lda * 2 + 2]; + vx[14] = a[lda * 3 + 2]; + } + Vtype vx1 = AIsSigned ? vx : vx - vmask; + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4s_impl(vx1); + D += 16; + a += 16; + } + A += lda * 4; + + vec_xst(vsum, 0, &(RowSumBuffer[0])); + + RowSumBuffer += 4; + CountM -= 4; + } + + // Process remaining rows of matrix A in a loop. + // + if (CountM <= 3 && CountM > 0) { + const uint8_t *a = A; + size_t y = CountK; + __vector int vsum = { 0 }; + + while (y >= 16) { + Vtype a4 = vmask; + Vtype a2 = vmask; + Vtype a3 = vmask; + Vtype a1 = *reinterpret_cast(&a[0]); + if (CountM == 3) { + a3 = *reinterpret_cast(&a[lda * 2]); + } + if (CountM >= 2) { + a2 = *reinterpret_cast(&a[lda]); + } + Vtype vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + Vtype vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + Vtype vx2 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_odd)); + Vtype vx3 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_odd)); + Vtype vx4 = vec_perm(vx, vx1, mask0); + Vtype vx5 = vec_perm(vx2, vx3, mask0); + Vtype vx6 = vec_perm(vx, vx1, mask3); + Vtype vx7 = vec_perm(vx2, vx3, mask3); + Vtype vx0 = AIsSigned ? 
vx4 : vx4 - vmask; + *reinterpret_cast(&D[0]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx5 : vx5 - vmask; + *reinterpret_cast(&D[16]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx6 : vx6 - vmask; + *reinterpret_cast(&D[32]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx7 : vx7 - vmask; + *reinterpret_cast(&D[48]) = vx0; + vsum += vec_sum4s_impl(vx0); + D += 16 * 4; + a += 16; + y -= 16; + } + while (y >= 4) + { + Vtype vb = vmask; + __vector int vx1 = reinterpret_cast<__vector int>(vb); + vx1[0] = *reinterpret_cast(&a[0]); + if (CountM >= 2) { + vx1[1] = *reinterpret_cast(&a[lda]); + } + if (CountM >= 3) { + vx1[2] = *reinterpret_cast(&a[lda*2]); + } + Vtype vx = AIsSigned ? reinterpret_cast(vx1) : reinterpret_cast(vx1) - vmask; + *reinterpret_cast(&D[0]) = vx; + vsum += vec_sum4s_impl(vx); + D += 16; + a += 4; + y -= 4; + } + if (y >= 1) + { + Vtype vx = (Vtype) vec_splats(0); + vx[0] = a[0] ^ Flip; + if (y >= 2) { + vx[1] = a[1] ^ Flip; + } + if (y >= 3) { + vx[2] = a[2] ^ Flip; + } + if (CountM >= 2) { + vx[4] = a[lda] ^ Flip; + if (y >= 2) { + vx[5] = a[lda + 1] ^ Flip; + } + if (y >= 3) { + vx[6] = a[lda + 2] ^ Flip; + } + } + if (CountM == 3) { + vx[8] = a[lda * 2] ^ Flip; + if (y >= 2) { + vx[9] = a[lda * 2 + 1] ^ Flip; + } + if (y >= 3) { + vx[10] = a[lda * 2 + 2] ^ Flip; + } + } + *reinterpret_cast(&D[0]) = vx; + vsum += vec_sum4s_impl(vx); + D += 16; + } + *RowSumBuffer++ = vsum[0]; + if (CountM >= 2) { + *RowSumBuffer++ = vsum[1]; + } + if (CountM >= 3) { + *RowSumBuffer++ = vsum[2]; + } + } +} + +template +void +MlasGemmQuantCopyPackB8x8( + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer + ) +{ + [[maybe_unused]] constexpr uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); + typedef __vector unsigned char vec_t; + Vtype vmask = reinterpret_cast(vec_splats(BitFlipValue)); + vec_t mask = {0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15}; + + const __vector unsigned char vec_zero = { 0 }; + + // Copy columns from matrix B to the packed buffer. Signed buffers are + // converted to unsigned buffers in order to share a common kernel. + // + // If CountK is not aligned to a multiple of four, then the packed buffer + // is padded with zero vectors. + + // Process 16 columns of matrix B in a loop. + // + size_t PackedK = ((CountK + 4 - 1) / 4) * 16; + size_t k2 = PackedK; + size_t k3 = PackedK*2; + size_t k4 = PackedK*3; + + while (CountN >= 16) { + const uint8_t* b = B; + __vector unsigned int vsum = {0}; + __vector unsigned int vsum2 = {0}; + __vector unsigned int vsum3 = {0}; + __vector unsigned int vsum4 = {0}; + size_t y = CountK; + if (y >= 4) { + do { + Vtype b1 = *reinterpret_cast(&b[0]); + Vtype b2 = *reinterpret_cast(&b[ldb]); + Vtype b3 = *reinterpret_cast(&b[ldb*2]); + Vtype b4 = *reinterpret_cast(&b[ldb*3]); + Vtype t1 = vec_mergeh(b1, b3); + Vtype t2 = vec_mergel(b1, b3); + Vtype t3 = vec_mergeh(b2, b4); + Vtype t4 = vec_mergel(b2, b4); + b1 = vec_mergeh(t1, t3); + b2 = vec_mergel(t1, t3); + b3 = vec_mergeh(t2, t4); + b4 = vec_mergel(t2, t4); + vec_t vx1 = BIsSigned ? reinterpret_cast(b1 + vmask) : + reinterpret_cast(b1); + vec_t vx2 = BIsSigned ? reinterpret_cast(b2 + vmask) : + reinterpret_cast(b2); + vec_t vx3 = BIsSigned ? reinterpret_cast(b3 + vmask) : + reinterpret_cast(b3); + vec_t vx4 = BIsSigned ? 
reinterpret_cast(b4 + vmask) : + reinterpret_cast(b4); + *reinterpret_cast(&D[0]) = vx1; + *reinterpret_cast(&D[k2]) = vx2; + *reinterpret_cast(&D[k3]) = vx3; + *reinterpret_cast(&D[k4]) = vx4; + vsum += vec_sum4(vx1, vec_zero); + vsum2 += vec_sum4(vx2, vec_zero); + vsum3 += vec_sum4(vx3, vec_zero); + vsum4 += vec_sum4(vx4, vec_zero); + D += 16; + b += ldb*4; + y -= 4; + } while (y >= 4); + } + if (y >= 1) { + Vtype b1 = *reinterpret_cast(&b[0]); + Vtype b2 = (y >= 2) ? *reinterpret_cast(&b[ldb]) : vmask; + Vtype b3 = (y >= 3) ? *reinterpret_cast(&b[ldb*2]) : vmask; + Vtype b4 = vmask; + Vtype t1 = vec_mergeh(b1, b3); + Vtype t2 = vec_mergel(b1, b3); + Vtype t3 = vec_mergeh(b2, b4); + Vtype t4 = vec_mergel(b2, b4); + b1 = vec_mergeh(t1, t3); + b2 = vec_mergel(t1, t3); + b3 = vec_mergeh(t2, t4); + b4 = vec_mergel(t2, t4); + vec_t vx1 = BIsSigned ? reinterpret_cast(b1 + vmask) : + reinterpret_cast(b1); + vec_t vx2 = BIsSigned ? reinterpret_cast(b2 + vmask) : + reinterpret_cast(b2); + vec_t vx3 = BIsSigned ? reinterpret_cast(b3 + vmask) : + reinterpret_cast(b3); + vec_t vx4 = BIsSigned ? reinterpret_cast(b4 + vmask) : + reinterpret_cast(b4); + *reinterpret_cast(&D[0]) = vx1; + *reinterpret_cast(&D[k2]) = vx2; + *reinterpret_cast(&D[k3]) = vx3; + *reinterpret_cast(&D[k4]) = vx4; + vsum += vec_sum4(vx1, vec_zero); + vsum2 += vec_sum4(vx2, vec_zero); + vsum3 += vec_sum4(vx3, vec_zero); + vsum4 += vec_sum4(vx4, vec_zero); + D += 16; + } + + vec_xst(vsum, 0, (unsigned int*) ColumnSumBuffer); + vec_xst(vsum2, 16, (unsigned int*) ColumnSumBuffer); + vec_xst(vsum3, 32, (unsigned int*) ColumnSumBuffer); + vec_xst(vsum4, 48, (unsigned int*) ColumnSumBuffer); + + ColumnSumBuffer += 16; + B += 16; + CountN -= 16; + D += k4; + } + + // Process four columns of matrix B in a loop. + // + while (CountN >= 4) { + const uint8_t* b = B; + __vector unsigned int vsum = {0}; + size_t y = CountK; + if (y >= 4) { + do { + int b1 = *reinterpret_cast(&b[0]); + int b2 = *reinterpret_cast(&b[ldb]); + int b3 = *reinterpret_cast(&b[ldb*2]); + int b4 = *reinterpret_cast(&b[ldb*3]); + __vector int vb = {b1, b2, b3, b4}; + Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); + vec_t vx1 = BIsSigned ? reinterpret_cast(vx + vmask) : + reinterpret_cast(vx); + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4(vx1, vec_zero); + D += 16; + b += ldb*4; + y -= 4; + } while (y >= 4); + } + if (y >= 1) { + Vtype vb = vmask; + __vector int vb1 = reinterpret_cast<__vector int>(vb); + vb1[0] = *reinterpret_cast(&b[0]); + if (y >= 2) { + vb1[1] = *reinterpret_cast(&b[ldb]); + } + if (y >= 3) { + vb1[2] = *reinterpret_cast(&b[ldb*2]); + } + Vtype vx = vec_perm(reinterpret_cast(vb1), reinterpret_cast(vb1), mask); + vec_t vx1 = BIsSigned ? reinterpret_cast(vx + vmask) : + reinterpret_cast(vx); + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4(vx1, vec_zero); + D += 16; + } + + vec_xst(vsum, 0, (unsigned int*) ColumnSumBuffer); + + ColumnSumBuffer += 4; + B += 4; + CountN -= 4; + } + + // + // Process the remaining columns of matrix B. 
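Editor's note: the column sums above are produced with vec_sum4 against a zero addend, which, as used here (and like the vec_sum4s_impl helper earlier in this file), collapses each aligned group of four bytes into one 32-bit lane. A scalar reference under that assumption; the name is illustrative.

// Byte reduction used when building the column sums above.
#include <cstdint>

void Sum4Reference(const uint8_t bytes[16], int32_t lanes[4])
{
    for (int lane = 0; lane < 4; ++lane) {
        lanes[lane] = 0;
        for (int i = 0; i < 4; ++i) {
            lanes[lane] += bytes[lane * 4 + i];
        }
    }
}
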
+ // + if (CountN > 0) { + __vector unsigned int vsum = {0}; + const uint8_t* b = B; + size_t y = CountK; + if (y >= 4) { + do { + Vtype vb = vmask; + if (CountN == 1) { + vb[0] = b[0]; + vb[4] = b[ldb]; + vb[8] = b[ldb*2]; + vb[12] = b[ldb*3]; + } + if (CountN == 2) { + vb[0] = b[0]; + vb[1] = b[1]; + vb[4] = b[ldb]; + vb[5] = b[ldb+1]; + vb[8] = b[ldb*2]; + vb[9] = b[ldb*2+1]; + vb[12] = b[ldb*3]; + vb[13] = b[ldb*3+1]; + } + if (CountN == 3) { + vb[0] = b[0]; + vb[1] = b[1]; + vb[2] = b[2]; + vb[4] = b[ldb]; + vb[5] = b[ldb+1]; + vb[6] = b[ldb+2]; + vb[8] = b[ldb*2]; + vb[9] = b[ldb*2+1]; + vb[10] = b[ldb*2+2]; + vb[12] = b[ldb*3]; + vb[13] = b[ldb*3+1]; + vb[14] = b[ldb*3+2]; + } + Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); + vec_t vx1 = BIsSigned ? reinterpret_cast(vx + vmask) : + reinterpret_cast(vx); + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4(vx1, vec_zero); + D += 16; + b += ldb*4; + y -= 4; + } while (y >= 4); + } + if (y >= 1) { + Vtype vb = vmask; + if (CountN == 1) { + vb[0]= b[0]; + if (y >= 2) { + vb[4] = b[ldb]; + } + if (y >= 3) { + vb[8] = b[ldb*2]; + } + } + if (CountN == 2) { + vb[0] = b[0]; + vb[1] = b[1]; + if (y >= 2) { + vb[4] = b[ldb]; + vb[5] = b[ldb+1]; + } + if (y >= 3) { + vb[8] = b[ldb*2]; + vb[9] = b[ldb*2+1]; + } + } + if (CountN == 3) { + vb[0] = b[0]; + vb[1] = b[1]; + vb[2] = b[2]; + if (y >= 2) { + vb[4] = b[ldb]; + vb[5] = b[ldb+1]; + vb[6] = b[ldb+2]; + } + if (y >= 3) { + vb[8] = b[ldb*2]; + vb[9] = b[ldb*2+1]; + vb[10] = b[ldb*2+2]; + } + } + Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); + vec_t vx1 = BIsSigned ? reinterpret_cast(vx + vmask) : + reinterpret_cast(vx); + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4(vx1, vec_zero); + D += 16; + } + *ColumnSumBuffer++ = vsum[0]; + if (CountN >= 2) { + *ColumnSumBuffer++ = vsum[1]; + } + if (CountN >= 3) { + *ColumnSumBuffer++ = vsum[2]; + } + } +} + +template<> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned + ) +{ + if (AIsSigned) { + MlasGemmQuantCopyPackA8x8<__vector signed char, true> (D, A, lda, CountM, CountK, RowSumBuffer); + } else { + MlasGemmQuantCopyPackA8x8<__vector unsigned char, false>(D, A, lda, CountM, CountK, RowSumBuffer); + } +} +template<> +void +MlasGemmQuantCopyPackB( + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + if (BIsSigned) { + MlasGemmQuantCopyPackB8x8<__vector signed char, true>(D, B, ldb, CountN, CountK, ColumnSumBuffer); + } else { + MlasGemmQuantCopyPackB8x8< __vector unsigned char, false>(D, B, ldb, CountN, CountK, ColumnSumBuffer); + } +} + +template +MLAS_FORCEINLINE +void +MlasQgemmStoreVectorZVECTOR + ( + MLAS_INT32X4 result[4], + int32_t* C, + size_t ldc, + size_t row, + bool ZeroMode, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + int pos + ) +{ + size_t RowCount; + __vector signed int vsum0, vsum1, vsum2, vsum3; + __vector signed int columnsum = *reinterpret_cast(&ColumnSumBuffer[pos]); + C += VectorCount; + if (ZeroPointB != nullptr) { + __vector signed int zeropoint = *reinterpret_cast(&ZeroPointB[pos]); + if (ZeroMode) { + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; + vsum1 = 
vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + } + } else { + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + } + } + } else { + if (ZeroMode) { + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + } + } else { + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) += + 
*reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + } + } + } +}; +template +MLAS_FORCEINLINE +void +MlasQgemmStoreScalarZVECTOR( + MLAS_INT32X4 result[4], + int32_t* C, + size_t ldc, + size_t row, + bool ZeroMode, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB + ) +{ + if (ZeroPointB != nullptr) { + if (ZeroMode) { + for (size_t RowCount = 0;RowCount < row; RowCount++){ + int sum = RowSumBuffer[RowCount]; + sum *= ZeroPointB[0]; + sum += ColumnSumBuffer[0]; + C[RowCount*ldc+Lane] = result[RowCount][Lane] + sum; + } + } else { + for (size_t RowCount = 0;RowCount < row; RowCount++){ + int sum = RowSumBuffer[RowCount]; + sum *= ZeroPointB[0]; + sum += ColumnSumBuffer[0]; + C[RowCount*ldc+Lane] += result[RowCount][Lane] + sum; + } + } + } else { + if (ZeroMode) { + for (size_t RowCount = 0;RowCount < row; RowCount++){ + int sum = RowSumBuffer[RowCount] + ColumnSumBuffer[0]; + C[RowCount*ldc+Lane] = result[RowCount][Lane] + sum; + } + } else { + for (size_t RowCount = 0;RowCount < row; RowCount++){ + int sum = RowSumBuffer[RowCount] + ColumnSumBuffer[0]; + C[RowCount*ldc+Lane] += result[RowCount][Lane] + sum; + } + } + } +}; + +MLAS_FORCEINLINE +void +xvi8ger4pp_impl( + MLAS_INT32X4 acc[4], + __vector unsigned char va, + __vector unsigned char vb + ) +{ + const __vector unsigned char maska[4] = { + { 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2 ,3, 0, 1, 2, 3 }, + { 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7 }, + { 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11 }, + { 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15 } + }; + + const __vector unsigned char maskb = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + + __vector int va_interim[4]; + __vector unsigned char vb_interm = vec_perm(vb, vb, maskb); + + __vector int va_prep; + __vector int vb_prep[4]; + + va_interim[0] = (__vector int) vec_unpackh(vec_unpackh((__vector signed char) va)); + va_interim[1] = (__vector int) vec_unpackl(vec_unpackh((__vector signed char) va)); + va_interim[2] = (__vector int) vec_unpackh(vec_unpackl((__vector signed char) va)); + va_interim[3] = (__vector int) vec_unpackl(vec_unpackl((__vector signed char) va)); + + vb_prep[0] = (__vector int) vec_unpackh(vec_unpackh(vb_interm)); + vb_prep[1] = (__vector int) vec_unpackl(vec_unpackh(vb_interm)); + vb_prep[2] = (__vector int) vec_unpackh(vec_unpackl(vb_interm)); + vb_prep[3] = (__vector int) vec_unpackl(vec_unpackl(vb_interm)); + + for (size_t i = 0; i < 4; ++i) + { + for (size_t k = 0; k < 4; ++k) + { + va_prep = vec_perm(va_interim[i], va_interim[i], maska[k]); + + acc[i] += va_prep * vb_prep[k]; + } + } +} + +template +MLAS_FORCEINLINE +void +MlasQgemmComputeZVECTOR( + MLAS_INT32X4 acc0[4], + MLAS_INT32X4 acc1[4], + __vector unsigned char *va, + __vector unsigned char *vb + ) +{ + if (CountK == 16) { + xvi8ger4pp_impl(acc0, va[0], vb[0]); + xvi8ger4pp_impl(acc0, va[1], vb[1]); + xvi8ger4pp_impl(acc0, va[2], vb[2]); + xvi8ger4pp_impl(acc0, va[3], vb[3]); + if (CountM) { + xvi8ger4pp_impl(acc1, va[4], vb[0]); + 
xvi8ger4pp_impl(acc1, va[5], vb[1]); + xvi8ger4pp_impl(acc1, va[6], vb[2]); + xvi8ger4pp_impl(acc1, va[7], vb[3]); + } + } else if (CountK == 12) { + xvi8ger4pp_impl(acc0, va[0], vb[0]); + xvi8ger4pp_impl(acc0, va[1], vb[1]); + xvi8ger4pp_impl(acc0, va[2], vb[2]); + if (CountM) { + xvi8ger4pp_impl(acc1, va[3], vb[0]); + xvi8ger4pp_impl(acc1, va[4], vb[1]); + xvi8ger4pp_impl(acc1, va[5], vb[2]); + } + } else if (CountK == 8) { + xvi8ger4pp_impl(acc0, va[0], vb[0]); + xvi8ger4pp_impl(acc0, va[1], vb[1]); + if (CountM) { + xvi8ger4pp_impl(acc1, va[2], vb[0]); + xvi8ger4pp_impl(acc1, va[3], vb[1]); + } + } else { + xvi8ger4pp_impl(acc0, va[0], vb[0]); + if (CountM) { + xvi8ger4pp_impl(acc1, va[1], vb[0]); + } + } +}; +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedAType* A, + const MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + if (CountM < 8 && CountM >= 4) { + CountM = 4; + } + size_t Mval = CountM; + if (Mval >= 8) { + Mval = 4; + } + while (CountN > 0) { + const int8_t *a = A; + typedef __vector unsigned char vec_t; + const uint8_t *b = B; + int32_t *C1; + MLAS_INT32X4 acc0[4] = {0}; + MLAS_INT32X4 acc1[4] = {0}; + MLAS_INT32X4 acc2[4] = {0}; + MLAS_INT32X4 acc3[4] = {0}; + MLAS_INT32X4 acc4[4] = {0}; + MLAS_INT32X4 acc5[4] = {0}; + MLAS_INT32X4 acc6[4] = {0}; + MLAS_INT32X4 acc7[4] = {0}; + MLAS_INT32X4 result[4] = {0}; + MLAS_INT32X4 result1[4] = {0}; + size_t k = PackedCountK * MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedK; + size_t k1 = PackedCountK; + // + // Compute the output block using POWER10 MMA builtins. 
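Editor's note: xvi8ger4pp_impl above emulates a POWER10-style rank-4 byte multiply-accumulate using ZVECTOR permutes and widening unpacks. Reading the permute masks, each call updates a 4x4 int32 tile from four signed A bytes per row and four unsigned B bytes per column; the scalar model below reflects that reading, with the caveat that the exact lane ordering follows the packed layouts built by the copy routines in this file. Names are illustrative, not MLAS API.

// Scalar model of the 4x4 rank-4 update performed by xvi8ger4pp_impl above.
#include <cstdint>

void I8Ger4ppReference(int32_t acc[4][4], const int8_t a[16], const uint8_t b[16])
{
    for (int i = 0; i < 4; ++i) {          // output row
        for (int j = 0; j < 4; ++j) {      // output column
            for (int k = 0; k < 4; ++k) {  // the four packed K values
                acc[i][j] += int32_t(a[4 * i + k]) * int32_t(b[4 * j + k]);
            }
        }
    }
}
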
+ // + while (k >= 16) { + vec_t *va = const_cast(reinterpret_cast(a)); + vec_t *vb = const_cast(reinterpret_cast(b)); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*16])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*32])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*48])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } + b += 64; + if (CountM >= 8) { + a += 128; + } else { + a += 64; + } + k -= 16; + } + if (k >= 12) { + vec_t *va = const_cast(reinterpret_cast(a)); + vec_t *vb = const_cast(reinterpret_cast(b)); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*16])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*32])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*48])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } + if (CountM >= 8) { + a += 96; + } else { + a += 48; + } + b += 48; + k -= 12; + } + if (k >= 8) { + vec_t *va = const_cast(reinterpret_cast(a)); + vec_t *vb = const_cast(reinterpret_cast(b)); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*16])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*32])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*48])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } + if (CountM >= 8) { + a += 64; + } else { + a += 32; + } + b += 32; + k -= 8; + } + if (k >= 4) { + vec_t *va = const_cast(reinterpret_cast(a)); + vec_t *vb = const_cast(reinterpret_cast(b)); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*16])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*32])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*48])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } + } + // Store matrix C with accumulator result. 
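Editor's note: before the store code below, it may help to see what MlasQgemmStoreVectorZVECTOR (defined earlier) adds to the raw accumulator tile: each element receives the precomputed row-sum and column-sum correction, scaled by ZeroPointB when one is supplied, and ZeroMode selects between overwriting and accumulating into C. The scalar sketch is illustrative only.

// Scalar sketch of the fixup applied when an accumulator tile is written to C.
#include <cstdint>
#include <cstddef>

void StoreTileReference(const int32_t acc[4][4], int32_t* C, size_t ldc,
                        size_t rows, bool zeroMode,
                        const int32_t* rowSum, const int32_t* colSum,
                        const int32_t* zeroPointB /* may be nullptr */)
{
    for (size_t r = 0; r < rows; ++r) {
        for (size_t n = 0; n < 4; ++n) {
            int32_t fixup = colSum[n] +
                rowSum[r] * (zeroPointB != nullptr ? zeroPointB[n] : 1);
            int32_t value = acc[r][n] + fixup;
            if (zeroMode) {
                C[r * ldc + n] = value;
            } else {
                C[r * ldc + n] += value;
            }
        }
    }
}
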
+ if (CountN >= 16) { + MlasQgemmStoreVectorZVECTOR<0>(acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorZVECTOR<8>(acc2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); + MlasQgemmStoreVectorZVECTOR<12>(acc3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12); + + if (CountM >= 8) { + C1 = C+ldc*4; + + MlasQgemmStoreVectorZVECTOR<0>(acc4, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc5, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorZVECTOR<8>(acc6, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); + MlasQgemmStoreVectorZVECTOR<12>(acc7, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 12); + } + INC_BUFFER(16); + CountN -= 16; + B += 16 * 4 *PackedCountK; + C += 16; + } else { + if (CountN >=12 ) { + MlasQgemmStoreVectorZVECTOR<0>(acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorZVECTOR<8>(acc2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); + if (CountM >= 8) { + C1 = C+ldc*4; + + MlasQgemmStoreVectorZVECTOR<0>(acc4, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc5, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorZVECTOR<8>(acc6, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); + } + INC_BUFFER(12); + if (CountN - 12 > 0) { + for (size_t i = 0; i < 4; ++i) { + result[i] = acc3[i]; + } + if (CountM >= 8) { + for (size_t i = 0; i < 4; ++i) { + result1[i] = acc7[i]; + } + } + } + CountN -= 12; + C += 12; + } else if (CountN >= 8) { + MlasQgemmStoreVectorZVECTOR<0>(acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + if (CountM >= 8) { + C1 = C+ldc*4; + + MlasQgemmStoreVectorZVECTOR<0>(acc4, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc5, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); + } + INC_BUFFER(8); + if (CountN - 8 > 0) { + for (size_t i = 0; i < 4; ++i) { + result[i] = acc2[i]; + } + if (CountM >= 8) { + for (size_t i = 0; i < 4; ++i) { + result1[i] = acc6[i]; + } + } + } + CountN -= 8; + C += 8; + } else if (CountN >= 4) { + MlasQgemmStoreVectorZVECTOR<0>(acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + if (CountM >= 8) { + C1 = C+ldc*4; + + MlasQgemmStoreVectorZVECTOR<0>(acc4, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); + } + INC_BUFFER(4); + if (CountN - 4 > 0) { + for (size_t i = 0; i < 4; ++i) { + result[i] = acc1[i]; + } + if (CountM >= 8) { + for (size_t i = 0; i < 4; ++i) { + result1[i] = acc5[i]; + } + } + } + CountN -= 4; + C += 4; + } else { + for (size_t i = 0; i < 4; ++i) { + result[i] = acc0[i]; + } + if (CountM >= 8) { + for (size_t i = 0; i < 4; ++i) { + result1[i] = acc4[i]; + } + } + } + CountN &= 3; + // + // Output the remaining partial output block. 
+ // + if (CountN > 0) { + MlasQgemmStoreScalarZVECTOR<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + if (CountM >= 8) { + MlasQgemmStoreScalarZVECTOR<0>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); + } + INC_BUFFER(1); + if (CountN >= 2) { + MlasQgemmStoreScalarZVECTOR<1>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + if (CountM >= 8) { + MlasQgemmStoreScalarZVECTOR<1>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); + } + INC_BUFFER(1); + } + if (CountN >= 3) { + MlasQgemmStoreScalarZVECTOR<2>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + if (CountM >= 8) { + MlasQgemmStoreScalarZVECTOR<2>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); + } + INC_BUFFER(1); + } + } + CountN = 0; + } + } + if (CountM >= 8) { + return 8; + } + return CountM; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchZVECTOR = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedK, + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedStrides.K, + 8 // Kernel M stride +}; diff --git a/src/lib/saturation_check.cpp b/src/lib/saturation_check.cpp new file mode 100644 index 0000000..7b022a7 --- /dev/null +++ b/src/lib/saturation_check.cpp @@ -0,0 +1,42 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + saturation_check.cpp + +Abstract: + + This module implements logic to check saturation of the VPMADDUBSW + instruction. + +--*/ + +#include "mlasi.h" + +namespace onnxruntime +{ + +#if defined(MLAS_TARGET_AMD64) + +std::atomic saturation_count{0}; + +void +reset_saturation_count() +{ + saturation_count = 0; +} + +#else + +void +reset_saturation_count() +{ +} + +#endif + +} // namespace onnxruntime diff --git a/src/lib/scalar/SgemmKernelScalar.cpp b/src/lib/scalar/SgemmKernelScalar.cpp index 6272925..cbec5d8 100644 --- a/src/lib/scalar/SgemmKernelScalar.cpp +++ b/src/lib/scalar/SgemmKernelScalar.cpp @@ -83,6 +83,8 @@ Return Value: #endif + int countb = 0; + do { float BElements00; @@ -116,6 +118,7 @@ Return Value: // const float* a = A; + const float* b = B; size_t k = CountK; while (k >= 2) { @@ -128,10 +131,10 @@ Return Value: Row1AElements1 = a[lda + 1]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -144,10 +147,10 @@ Return Value: Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - BElements00 = B[4]; - BElements01 = B[5]; - BElements02 = B[6]; - BElements03 = B[7]; + BElements00 = b[16]; + BElements01 = b[17]; + BElements02 = b[18]; + BElements03 = b[19]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements1; Row0Block01 = Row0Block01 + BElements01 * Row0AElements1; Row0Block02 = Row0Block02 + BElements02 * Row0AElements1; @@ -161,7 +164,7 @@ Return Value: } a += 2; - B += 8; + b += 32; k -= 2; } @@ -173,10 +176,10 @@ Return Value: Row1AElements0 = a[lda]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 
+ BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -188,8 +191,6 @@ Return Value: Row1Block02 = Row1Block02 + BElements02 * Row1AElements0; Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - - B += 4; } // @@ -295,9 +296,14 @@ Return Value: break; } + B += 4; C += 4; CountN -= 4; + countb = (countb + 1) % 4; + if (countb == 0) { + B += CountK * 16 - 16; + } } while (CountN > 0); return ProcessTwoRows ? 2 : 1; diff --git a/src/lib/sconv.h b/src/lib/sconv.h new file mode 100644 index 0000000..12ccff2 --- /dev/null +++ b/src/lib/sconv.h @@ -0,0 +1,29 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv.h + +Abstract: + + This module defines convolution kernel flags for configuring convolution + operations including output accumulation, bias addition, and activations. + +--*/ + +// +// Define the convolution kernel flags. +// + +#if defined(MLAS_USE_ARM_NEON_NCHWC) + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 + +#endif diff --git a/src/lib/sconv_kernel_neon.cpp b/src/lib/sconv_kernel_neon.cpp new file mode 100644 index 0000000..4c5f50a --- /dev/null +++ b/src/lib/sconv_kernel_neon.cpp @@ -0,0 +1,524 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv_kernel_neon.cpp + +Abstract: + + This module implements the single precision convolution kernels for ARM NEON. + +--*/ + +#if defined(MLAS_USE_ARM_NEON_NCHWC) + +#include "mlasi.h" +#include "sconv.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +// Common implementation for NCHW and NCHWC convolution kernels +template +void + MLASCALL + MlasConvFloatKernelNeonImpl( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + MLAS_UNREFERENCED_PARAMETER(InputStride); + + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount); + + 
for (size_t filterSetBlock = 0; filterSetBlock < FilterCount; filterSetBlock++) { + const float* filter = Filter + filterSetBlock * FilterStrideElements; + float* output = Output + filterSetBlock * OutputStrideElements; + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 4]); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 8]); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 12]); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_base = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if constexpr (IsNchwcFormat) { + for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) { + const float* input_element = input_base + filterBlock; + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + float input_value; + if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) { + input_value = *input_element; + } else { + input_value = 0.0f; + } + + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) + + kw * (BlockSize * BlockSize) + + filterBlock * BlockSize; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&filter[kernel_base_pos]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&filter[kernel_base_pos + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&filter[kernel_base_pos + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&filter[kernel_base_pos + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } else { + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + float input_value; + if (is_main_region || (input_base >= input_row_start && input_base < input_row_end)) { + input_value = *input_base; + } else { + input_value = 0.0f; + } + + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + size_t kernel_base_pos = kh * KernelWidth + kw; + + const 
float32x4_t FilterVector0 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3); + } + } +} + +void + MLASCALL + MlasConvNchwFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + MlasConvFloatKernelNeonImpl( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +} + +// +// Implementation of MlasConvNchwcFloatKernelNeon +// + +void + MLASCALL + MlasConvNchwcFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + MlasConvFloatKernelNeonImpl( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +} + +// +// Helper function to load input vector with bounds checking +// +static inline float32x4_t +LoadInputVectorWithBounds( + const float* input_base, + size_t offset, + bool is_main_region, + const float* InputBase, + size_t kh, + size_t DilatedInputWidthElements, + size_t InputWidthElements +) +{ + if (is_main_region) { + return MlasLoadFloat32x4(input_base + offset); + } else { + float input_values[4]; + for (size_t i = 0; i < 4; i++) { + const float* input_element = input_base + offset + i; + const float* 
input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + if (input_element >= input_row_start && input_element < input_row_end) { + input_values[i] = *input_element; + } else { + input_values[i] = 0.0f; + } + } + return MlasLoadFloat32x4(input_values); + } +} + +// +// Implementation of MlasConvDepthwiseFloatKernelNeon +// +// This kernel performs depthwise separable convolution where each input channel +// is convolved with its own filter. This is more efficient than standard convolution +// for certain network architectures like MobileNets. +// + +void + MLASCALL + MlasConvDepthwiseFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + MLAS_UNREFERENCED_PARAMETER(InputStrideElements); + + const size_t InputWidthElements = InputWidth / sizeof(float); + + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount); + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&Output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(Bias); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(Bias + 4); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(Bias + 8); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(Bias + 12); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + for (size_t kw = 0; kw < KernelWidth; kw++) { + size_t kernel_pos = kh * KernelWidth + kw; + + const float* input_base = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + float32x4_t InputVector0 = LoadInputVectorWithBounds(input_base, 
0, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector1 = LoadInputVectorWithBounds(input_base, 4, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector2 = LoadInputVectorWithBounds(input_base, 8, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector3 = LoadInputVectorWithBounds(input_base, 12, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector0, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector1, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector2, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector3, FilterVector3, Accumulator3); + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], Accumulator3); + } +} + +// +// Implementation of MlasConvPointwiseFloatKernelNeon +// +// This kernel performs pointwise (1x1) convolution which is essentially +// a matrix multiplication across the channel dimension. It's optimized +// for cases where the kernel size is 1x1. 
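Editor's note: as a sketch of the description above, a 1x1 convolution in NCHWc layout reduces, at one output position and one filter block, to a small matrix product over the input-channel blocks. Bias addition, output accumulation and ReLU (handled by the kernel below via KernelFlags) are omitted, and the function name and parameters are illustrative, not MLAS API.

// Scalar reference for the pointwise kernel below, one output position x and
// one filter block at a time.
#include <cstddef>

void PointwiseReference(const float* input, const float* filter, float* output,
                        size_t inputChannelBlocks, size_t blockSize,
                        size_t inputStride, size_t strideWidth, size_t x)
{
    // A (1 x C*blockSize) by (C*blockSize x blockSize) matrix product.
    for (size_t j = 0; j < blockSize; ++j) {
        float acc = 0.0f;
        for (size_t c = 0; c < inputChannelBlocks; ++c) {
            for (size_t i = 0; i < blockSize; ++i) {
                acc += input[c * inputStride + x * strideWidth + i] *
                       filter[(c * blockSize + i) * blockSize + j];
            }
        }
        output[x * blockSize + j] = acc;
    }
}
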
+// + +void + MLASCALL + MlasConvPointwiseFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t InputChannels, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + for (size_t output_idx = 0; output_idx < OutputCount; output_idx++) { + for (size_t f = 0; f < FilterCount; f++) { + const float* filter = Filter + f * FilterStrideElements; + float* output = Output + f * OutputStrideElements; + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(&Bias[f * BlockSize]); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t c = 0; c < InputChannels; c++) { + const float* input_ptr = Input + c * InputStrideElements + output_idx * StrideWidthElements; + + for (size_t input_b = 0; input_b < BlockSize; input_b++) { + const float input_value = input_ptr[input_b]; + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + const float* filter_ptr = filter + (c * BlockSize + input_b) * BlockSize; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(filter_ptr); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(filter_ptr + 4); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(filter_ptr + 8); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(filter_ptr + 12); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + 
Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3); + } + } +} + +#endif diff --git a/src/lib/sgemm.cpp b/src/lib/sgemm.cpp index 4d7a1ce..84c26eb 100644 --- a/src/lib/sgemm.cpp +++ b/src/lib/sgemm.cpp @@ -1061,7 +1061,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { @@ -1158,6 +1158,7 @@ Return Value: if (M == 1 && TransA == CblasNoTrans && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { +#if !defined(FORCE_GENERIC_ALGORITHMS) #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; @@ -1181,6 +1182,7 @@ Return Value: } #endif +#endif // !defined(FORCE_GENERIC_ALGORITHMS) } @@ -1193,7 +1195,7 @@ Return Value: if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) && !defined(FORCE_GENERIC_ALGORITHMS) MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; @@ -1570,7 +1572,13 @@ MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - + // Override + if(GetMlasPlatform().MlasGemmBatchOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans && + GetMlasPlatform().MlasGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchSize, ThreadPool)){ + return; + } // // Compute the number of target threads given the complexity of the SGEMM // operation. Small requests should run using the single threaded path. @@ -1578,14 +1586,7 @@ MlasGemmBatch( const double Complexity = double(M) * double(N) * double(K); - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); if (TargetThreadCount >= MaximumThreadCount) { @@ -1642,6 +1643,8 @@ MlasGemmBatch( size_t MLASCALL MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, size_t N, size_t K ) @@ -1666,6 +1669,22 @@ Return Value: // // Compute the number of bytes required to hold the packed buffer. 
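// ---------------------------------------------------------------------------
// A minimal sketch of the round-up idiom the packed-buffer size computation below
// relies on for AlignedN: rounding a dimension up to a power-of-two multiple with
// (n + a - 1) & ~(a - 1). The alignment value in the example is arbitrary and the
// helper name is illustrative, not an MLAS API.
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstddef>

static size_t RoundUpToMultiple(size_t n, size_t alignment /* power of two */)
{
    return (n + alignment - 1) & ~(alignment - 1);
}

static void RoundUpExample()
{
    assert(RoundUpToMultiple(100, 16) == 112);  // 100 rounds up to the next multiple of 16
    assert(RoundUpToMultiple(112, 16) == 112);  // already aligned values are unchanged
}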
// + // KleidiAI or other override + #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (GetMlasPlatform().MlasGemmPackBSizeOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans) { + size_t bytes_required; + //TODO pass status by reference to indicate success/fail + bytes_required = GetMlasPlatform().MlasGemmPackBSizeOverride(TransA, TransB, N, K); + if (bytes_required != 0){// If ArmKleidiAI::MlasGemmPackBSize ran to completion + return bytes_required; + } + } + #endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + const size_t AlignedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); @@ -1681,6 +1700,7 @@ Return Value: void MLASCALL MlasGemmPackB( + CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, @@ -1717,6 +1737,17 @@ Return Value: --*/ { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (GetMlasPlatform().MlasGemmPackBOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans && + GetMlasPlatform().MlasGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + return; + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + + const size_t AlignedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); diff --git a/src/lib/snchwc.cpp b/src/lib/snchwc.cpp index f9cf160..6f3423a 100644 --- a/src/lib/snchwc.cpp +++ b/src/lib/snchwc.cpp @@ -101,7 +101,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) return GetMlasPlatform().NchwcBlockSize; #else return 1; @@ -674,7 +674,7 @@ struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; @@ -784,7 +784,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; @@ -879,7 +879,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t FilterStrideBytes = BlockSize * InputChannels * sizeof(float); const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; @@ -1016,7 +1016,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if 
defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; @@ -1093,7 +1093,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM { -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; #endif @@ -1131,7 +1131,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; #else MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; @@ -1197,7 +1197,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM } }; -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = { @@ -1621,7 +1621,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) // // Convolution and pooling kernel stubs for architectures that do not yet have diff --git a/src/lib/softmax.h b/src/lib/softmax.h new file mode 100644 index 0000000..69fe1ae --- /dev/null +++ b/src/lib/softmax.h @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + softmax.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + softmax. + +--*/ + +#pragma once + +#include "mlasi.h" + +struct MLAS_SOFTMAX_DISPATCH { + /** + * @brief Compute the hyperbolic tangent function for each element of the input array + * @param Input Address of the input array. Valid in [-3.51562, 3.51562]. + * @param Output Address of the output array. Could be the same as the input array. + * @param N Number of elements in the input array + */ + typedef void(Tanh_Fp16_Fn)( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N + ); + + Tanh_Fp16_Fn* Tanh_Fp16 = nullptr; + + /** + * @brief Compute the softcap function for each element of the input array. Use tanh activation. + * @param Input Address of the input array. Valid if input / softcap in [-3.51562, 3.51562]. + * @param Output Address of the output array. Could be the same as the input array. 
+ * @param N Number of elements in the input array
+ * @param Softcap The softcap value
+ */
+ typedef void(Softcap_Fp16_Fn)(
+ const MLAS_FP16* Input,
+ MLAS_FP16* Output,
+ size_t N,
+ const MLAS_FP16 Softcap
+ );
+
+ Softcap_Fp16_Fn* Softcap_Fp16 = nullptr;
+
+ /**
+ * @brief Compute the exponential function for each element of the input array.
+ * @param Input Address of the input array. Valid in [-17.3287, 11.0904].
+ * @param Output Address of the output array. Could be the same as the input array.
+ * @param N Number of elements in the input array
+ */
+ typedef void(Exp_Fp16_Fn)(
+ const MLAS_FP16* Input,
+ MLAS_FP16* Output,
+ size_t N
+ );
+
+ Exp_Fp16_Fn* Exp_Fp16 = nullptr;
+
+ /**
+ * @brief Find the maximum value in the input array
+ * @param Input Address of the input array
+ * @param N Number of elements in the input array
+ */
+ typedef MLAS_FP16(ReduceMax_Fp16_Fn)(
+ const MLAS_FP16* Input,
+ size_t N
+ );
+
+ ReduceMax_Fp16_Fn* ReduceMax_Fp16 = nullptr;
+
+ /**
+ * @brief Compute the exponential function for each element of the input array and return the sum. It has a smaller
+ * dynamic range for the input than Exp_Fp16_Fn and is thus faster.
+ * @param Input Address of the input array. Valid in [-10.7438, 10.7438]
+ * @param Output Address of the output array. Could be the same as the input array or nullptr.
+ * @param N Number of elements in the input array
+ * @param NegativeMaximum The negative of the maximum value in the input array
+ */
+ typedef MLAS_FP16(SumExp_Fp16_Fn)(
+ const MLAS_FP16* Input,
+ MLAS_FP16* Output,
+ size_t N,
+ const MLAS_FP16 NegativeMaximum
+ );
+
+ SumExp_Fp16_Fn* SumExp_Fp16 = nullptr;
+
+ /**
+ * @brief Compute the softmax output for each element of the input array: input / sum.
+ * @param Input Address of the input array. Values of exp(x)
+ * @param Output Address of the output array. Could be the same as the input array.
+ * @param N Number of elements in the input array
+ * @param Sum Sum of exp(input)
+ */
+ typedef void(Softmax_Fp16_Fn)(
+ const MLAS_FP16* Input,
+ MLAS_FP16* Output,
+ size_t N,
+ const MLAS_FP16 Sum
+ );
+
+ Softmax_Fp16_Fn* Softmax_Fp16 = nullptr;
+
+ /**
+ * @brief Compute the log softmax output for each element of the input array: input - max - logSum
+ * @param Input Address of the input array
+ * @param Output Address of the output array. Could be the same as the input array.
+ * @param N Number of elements in the input array
+ * @param NegativeMaximum The negative of the maximum value in the input array
+ * @param LogSum The logarithm of the sum of the exponential function of the input array
+ */
+ typedef void(LogSoftmax_Fp16_Fn)(
+ const MLAS_FP16* Input,
+ MLAS_FP16* Output,
+ size_t N,
+ const MLAS_FP16 NegativeMaximum,
+ const MLAS_FP16 LogSum
+ );
+
+ LogSoftmax_Fp16_Fn* LogSoftmax_Fp16 = nullptr;
+};
diff --git a/src/lib/softmax_kernel_neon.cpp b/src/lib/softmax_kernel_neon.cpp
new file mode 100644
index 0000000..77ad4b9
--- /dev/null
+++ b/src/lib/softmax_kernel_neon.cpp
@@ -0,0 +1,38 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ softmax_kernel_neon.cpp
+
+Abstract:
+
+ This module implements the softmax kernels for ARM NEON.
+
+--*/
+
+#include "softmax.h"
+#include "softmax_kernel_neon.h"
+
+//
+// Kernel dispatch structure definition.
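// ---------------------------------------------------------------------------
// A scalar fp32 reference (illustrative only) of how the dispatch entries declared
// in softmax.h above are intended to compose: ReduceMax, then SumExp with the
// negated maximum, then a final normalization; LogSoftmax is x - max - log(sum).
// The helper names are not MLAS APIs, and both sketches assume n >= 1.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cmath>
#include <cstddef>

static void SoftmaxReference(const float* x, float* y, size_t n)
{
    float maximum = x[0];
    for (size_t i = 1; i < n; i++) {
        maximum = std::max(maximum, x[i]);   // ReduceMax
    }

    float sum = 0.0f;
    for (size_t i = 0; i < n; i++) {
        y[i] = std::exp(x[i] - maximum);     // SumExp: exp(x + NegativeMaximum)
        sum += y[i];
    }

    for (size_t i = 0; i < n; i++) {
        y[i] /= sum;                         // Softmax: exp(x - max) / sum
    }
}

static void LogSoftmaxReference(const float* x, float* y, size_t n)
{
    float maximum = x[0];
    for (size_t i = 1; i < n; i++) {
        maximum = std::max(maximum, x[i]);
    }

    float sum = 0.0f;
    for (size_t i = 0; i < n; i++) {
        sum += std::exp(x[i] - maximum);
    }

    const float log_sum = std::log(sum);
    for (size_t i = 0; i < n; i++) {
        y[i] = x[i] - maximum - log_sum;     // LogSoftmax: input - max - logSum
    }
}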
+// +const MLAS_SOFTMAX_DISPATCH MlasSoftmaxDispatchNeon = []() { + MLAS_SOFTMAX_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.Tanh_Fp16 = softmax_neon::Tanh_Kernel_Fp16; + d.Softcap_Fp16 = softmax_neon::Softcap_Kernel_Fp16; + d.Exp_Fp16 = softmax_neon::Exp_Kernel_Fp16; + d.ReduceMax_Fp16 = softmax_neon::ReduceMax_Kernel_Fp16; + d.SumExp_Fp16 = softmax_neon::SumExp_Kernel_Fp16; + d.Softmax_Fp16 = softmax_neon::Softmax_Kernel_Fp16; + d.LogSoftmax_Fp16 = softmax_neon::LogSoftmax_Kernel_Fp16; + } +#endif + return d; +}(); diff --git a/src/lib/softmax_kernel_neon.h b/src/lib/softmax_kernel_neon.h new file mode 100644 index 0000000..e362e5d --- /dev/null +++ b/src/lib/softmax_kernel_neon.h @@ -0,0 +1,40 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + softmax_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + softmax on ARM cpu. + +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace softmax_neon { + +void Tanh_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N); + +void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 Softcap); + +void Exp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N); + +MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N); + +MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum); + +void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 Sum); + +void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum, const MLAS_FP16 LogSum); + +} // namespace rope_neon diff --git a/src/lib/softmax_kernel_neon_fp16.cpp b/src/lib/softmax_kernel_neon_fp16.cpp new file mode 100644 index 0000000..dfd65d9 --- /dev/null +++ b/src/lib/softmax_kernel_neon_fp16.cpp @@ -0,0 +1,917 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + softmax_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 softmax kernels for ARM NEON. + +--*/ +#include +#include + +#include "fp16_common.h" +#include "softmax.h" +#include "softmax_kernel_neon.h" + +namespace softmax_neon { + +template +struct MlasExpConstants { + T LowerRange; + T UpperRange; + T LowerRangeSumExp; + T UpperRangeSumExp; + T RoundingBias; + T Log2Reciprocal; + T Log2High; + T Log2Mid; + T Log2Low; + T poly_0; + T poly_1; + T poly_2; + T poly_3; + T poly_4; + T poly_56; + T MinimumExponent; + T MaximumExponent; +}; + +constexpr MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = { + 0xcc55, // -25 * ln2 + 0x498c, // 16 * ln2 + 0xc95f, // -15.5 * ln2 + 0x495f, // 15.5 * ln2 + 0x6600, // 1.5 * 2^10 + 0x3dc5, // 1/ln2 + 0xb98b, // -6.9287109375e-1f16 + 0x8c85, // -2.758502960205078e-4f16 + 0x8004, // -2.384185791015625e-7f16 + 0x15b0, // 1/6! + 0x2044, // 1/5! + 0x2955, // 1/4! + 0x3155, // 1/3! + 0x3800, // 1/2! + 0x3c00, // 1/1! 
+ 0xC800, // -14 + 0x3C00, // 15 +}; + +template +MLAS_FORCEINLINE +const MlasExpConstants& Get_Exp_Constants(); + +template <> +MLAS_FORCEINLINE +const MlasExpConstants& Get_Exp_Constants() { + const static MlasExpConstants ExpConstantsFp16x8 = { + MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRange), + MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRange), + MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRangeSumExp), + MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRangeSumExp), + MlasBroadcastFloat16x8(ExpConstantsFp16.RoundingBias), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Reciprocal), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2High), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Mid), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Low), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_0), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_1), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_2), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_3), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_4), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_56), + MlasBroadcastFloat16x8(ExpConstantsFp16.MinimumExponent), + MlasBroadcastFloat16x8(ExpConstantsFp16.MaximumExponent), + }; + return ExpConstantsFp16x8; +} + +template <> +MLAS_FORCEINLINE +const MlasExpConstants& Get_Exp_Constants() { + const static MlasExpConstants ExpConstantsFp16x4 = { + MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRange), + MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRange), + MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRangeSumExp), + MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRangeSumExp), + MlasBroadcastFloat16x4(ExpConstantsFp16.RoundingBias), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Reciprocal), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2High), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Mid), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Low), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_0), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_1), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_2), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_3), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_4), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_56), + MlasBroadcastFloat16x4(ExpConstantsFp16.MinimumExponent), + MlasBroadcastFloat16x4(ExpConstantsFp16.MaximumExponent), + }; + return ExpConstantsFp16x4; +} + +// Range reduction + polynomial approximation. Refer algorithm details to MlasComputeExpVector. 
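// ---------------------------------------------------------------------------
// A scalar fp32 sketch (illustrative, not the fp16 vector kernel that follows) of
// the range reduction the comment above refers to: exp(x) = 2^m * exp(r) with
// m = round(x / ln2) and r = x - m * ln2, where exp(r) is approximated by a short
// polynomial and 2^m is built by writing m into the IEEE exponent field. It uses a
// plain Taylor polynomial rather than the tuned coefficients above and skips the
// overflow/underflow clamping the real kernel performs, so it assumes |x| is modest.
// ---------------------------------------------------------------------------
#include <cmath>
#include <cstdint>
#include <cstring>

static float ExpRangeReductionSketch(float x)
{
    const float log2_reciprocal = 1.4426950408889634f;   // 1 / ln2
    const float ln2 = 0.6931471805599453f;

    const float m = std::nearbyint(x * log2_reciprocal);  // integral part
    const float r = x - m * ln2;                          // residual, |r| <= ln2 / 2

    // Degree-4 polynomial approximation of exp(r) on the reduced range.
    const float p = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6.0f + r * (1.0f / 24.0f))));

    // Materialize 2^m by placing the biased exponent (m + 127) in the exponent bits.
    const int32_t bits = (static_cast<int32_t>(m) + 127) << 23;
    float two_to_m;
    std::memcpy(&two_to_m, &bits, sizeof(two_to_m));

    return two_to_m * p;
}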
+template +MLAS_FORCEINLINE +T Exp_Vector_Fp16(T x) { + const auto& constants = Get_Exp_Constants(); + auto clamped_x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange); + + // integral + auto biased = MlasMultiplyAddFloat16(clamped_x, constants.Log2Reciprocal, constants.RoundingBias); + auto m = MlasSubtractFloat16(biased, constants.RoundingBias); + + // residual + auto r = MlasMultiplyAddFloat16(m, constants.Log2High, clamped_x); + r = MlasMultiplyAddFloat16(m, constants.Log2Mid, r); + r = MlasMultiplyAddFloat16(m, constants.Log2Low, r); + + // handle overflow + auto overflow = MlasShiftLeftInt16<10>(MlasReinterpretFloat16AsInt16(biased)); + auto normal = overflow; + + auto minimum_exponent = MlasReinterpretFloat16AsInt16(constants.MinimumExponent); + auto maximum_exponent = MlasReinterpretFloat16AsInt16(constants.MaximumExponent); + normal = MlasClampInt16(normal, minimum_exponent, maximum_exponent); + + overflow = MlasSubtractInt16(overflow, normal); + overflow = MlasAddInt16(overflow, maximum_exponent); + normal = MlasAddInt16(normal, maximum_exponent); + + // polynomial approximation + auto p = MlasMultiplyAddFloat16(constants.poly_0, r, constants.poly_1); + p = MlasMultiplyAddFloat16(p, r, constants.poly_2); + p = MlasMultiplyAddFloat16(p, r, constants.poly_3); + p = MlasMultiplyAddFloat16(p, r, constants.poly_4); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); + + auto overflow_f = MlasReinterpretInt16AsFloat16(overflow); + r = MlasMultiplyFloat16(r, overflow_f); + p = MlasMultiplyAddFloat16(p, r, overflow_f); + p = MlasMultiplyFloat16(p, MlasReinterpretInt16AsFloat16(normal)); + + return p; +} + +void Exp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + auto r0 = Exp_Vector_Fp16(v0); + auto r1 = Exp_Vector_Fp16(v1); + auto r2 = Exp_Vector_Fp16(v2); + auto r3 = Exp_Vector_Fp16(v3); + + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + MlasStoreFloat16x8(output + 16, r2); + MlasStoreFloat16x8(output + 24, r3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + auto r0 = Exp_Vector_Fp16(v0); + auto r1 = Exp_Vector_Fp16(v1); + + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + auto r0 = Exp_Vector_Fp16(v0); + MlasStoreFloat16x8(output, r0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + auto r0 = Exp_Vector_Fp16(v0); + MlasStoreFloat16x4(output, r0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + auto r0 = Exp_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + auto r0 = Exp_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + auto r0 = Exp_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 1); + } +} + +// assume no overflow +template +MLAS_FORCEINLINE +T SumExp_Vector_Fp16(T x, T negative_maximum) { + const auto& constants = 
Get_Exp_Constants(); + auto clamped_x = MlasMaximumFloat16(MlasAddFloat16(x, negative_maximum), constants.LowerRangeSumExp); + + // integral + auto biased = MlasMultiplyAddFloat16(clamped_x, constants.Log2Reciprocal, constants.RoundingBias); + auto m = MlasSubtractFloat16(biased, constants.RoundingBias); + + // residual + auto r = MlasMultiplyAddFloat16(m, constants.Log2High, clamped_x); + r = MlasMultiplyAddFloat16(m, constants.Log2Mid, r); + r = MlasMultiplyAddFloat16(m, constants.Log2Low, r); + + // 2^m + auto normal = MlasShiftLeftInt16<10>(MlasReinterpretFloat16AsInt16(biased)); + normal = MlasAddInt16(normal, MlasReinterpretFloat16AsInt16(constants.MaximumExponent)); + + // polynomial approximation + auto p = MlasMultiplyAddFloat16(constants.poly_0, r, constants.poly_1); + p = MlasMultiplyAddFloat16(p, r, constants.poly_2); + p = MlasMultiplyAddFloat16(p, r, constants.poly_3); + p = MlasMultiplyAddFloat16(p, r, constants.poly_4); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); + + p = MlasMultiplyFloat16(p, MlasReinterpretInt16AsFloat16(normal)); + + return p; +} + +MLAS_FORCEINLINE +float16x8_t AddUp(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, float16x8_t v4) { + v0 = MlasAddFloat16(v0, v1); + v2 = MlasAddFloat16(v2, v3); + return MlasAddFloat16(MlasAddFloat16(v0, v2), v4); +} + +MLAS_FORCEINLINE +float16x8_t AddUp(float16x8_t v0, float16x8_t v1, float16x8_t v2) { + return MlasAddFloat16(MlasAddFloat16(v0, v1), v2); +} + +MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + float16x8_t negative_maximum8 = MlasBroadcastFloat16x8(NegativeMaximum.val); + float16x4_t negative_maximum4 = MlasBroadcastFloat16x4(NegativeMaximum.val); + const bool store_output = Output != nullptr; + float16x8_t accumulator8 = MlasZeroFloat16x8(); + float16x4_t accumulator4 = MlasZeroFloat16x4(); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum8); + auto r1 = SumExp_Vector_Fp16(v1, negative_maximum8); + auto r2 = SumExp_Vector_Fp16(v2, negative_maximum8); + auto r3 = SumExp_Vector_Fp16(v3, negative_maximum8); + + accumulator8 = AddUp(r0, r1, r2, r3, accumulator8); + + if (store_output) { + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + MlasStoreFloat16x8(output + 16, r2); + MlasStoreFloat16x8(output + 24, r3); + output += 32; + } + + input += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum8); + auto r1 = SumExp_Vector_Fp16(v1, negative_maximum8); + + accumulator8 = AddUp(r0, r1, accumulator8); + + if (store_output) { + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + output += 16; + } + + input += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum8); + accumulator8 = MlasAddFloat16(r0, accumulator8); + + if (store_output) { + MlasStoreFloat16x8(output, r0); + output += 8; + } + + input += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); + accumulator4 = 
MlasAddFloat16(r0, accumulator4); + + if (store_output) { + MlasStoreFloat16x4(output, r0); + output += 4; + } + + input += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); + + if (store_output) { + MlasStorePartialFloat16x4(output, r0, 3); + } + + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 3)); + accumulator4 = MlasAddFloat16(r0, accumulator4); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); + + if (store_output) { + MlasStorePartialFloat16x4(output, r0, 2); + } + + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 3)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 2)); + accumulator4 = MlasAddFloat16(r0, accumulator4); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); + + if (store_output) { + MlasStorePartialFloat16x4(output, r0, 1); + } + + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 3)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 2)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 1)); + accumulator4 = MlasAddFloat16(r0, accumulator4); + } + + auto t = MlasAddFloat16(vget_low_f16(accumulator8), vget_high_f16(accumulator8)); + t = MlasAddFloat16(t, accumulator4); + _mlas_fp16_ result = MlasReduceAddFloat16(t); + return MLAS_FP16::FromBits(result); +} + +template +struct MlasTanhConstants { + T LowerRange; + T UpperRange; + T alpha_7; + T alpha_5; + T alpha_3; + T alpha_1; + T beta_6; + T beta_4; + T beta_2; + T beta_0; +}; + +constexpr MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = { + 0xc308, // -3.51562 + 0x4308, // 3.51562 + 0x0001, + 0x00f9, + 0x1138, + 0x1d03, + 0x0014, + 0x07c5, + 0x18a5, + 0x1d03, +}; + +template +MLAS_FORCEINLINE +const MlasTanhConstants& Get_Tanh_Constants(); + +template <> +MLAS_FORCEINLINE +const MlasTanhConstants& Get_Tanh_Constants() { + const static MlasTanhConstants TanhConstantsFp16x8 = { + MlasBroadcastFloat16x8(TanhConstantsFp16.LowerRange), + MlasBroadcastFloat16x8(TanhConstantsFp16.UpperRange), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_7), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_5), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_3), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_1), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_6), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_4), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_2), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_0), + }; + return TanhConstantsFp16x8; +} + +template <> +MLAS_FORCEINLINE +const MlasTanhConstants& Get_Tanh_Constants() { + const static MlasTanhConstants TanhConstantsFp16x4 = { + MlasBroadcastFloat16x4(TanhConstantsFp16.LowerRange), + MlasBroadcastFloat16x4(TanhConstantsFp16.UpperRange), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_7), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_5), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_3), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_1), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_6), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_4), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_2), + 
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_0), + }; + return TanhConstantsFp16x4; +} + +// TODO(fajin): optimize polynomial coefficients +template +MLAS_FORCEINLINE +T Tanh_Vector_Fp16(T x) { + const auto& constants = Get_Tanh_Constants(); + x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange); + + T x_2 = MlasMultiplyFloat16(x, x); + + T p = MlasMultiplyAddFloat16(constants.alpha_7, x_2, constants.alpha_5); + p = MlasMultiplyAddFloat16(p, x_2, constants.alpha_3); + p = MlasMultiplyAddFloat16(p, x_2, constants.alpha_1); + p = MlasMultiplyFloat16(p, x); + + T q = MlasMultiplyAddFloat16(constants.beta_6, x_2, constants.beta_4); + q = MlasMultiplyAddFloat16(q, x_2, constants.beta_2); + q = MlasMultiplyAddFloat16(q, x_2, constants.beta_0); + + return MlasDivideFloat16(p, q); +} + +void Tanh_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + auto r0 = Tanh_Vector_Fp16(v0); + auto r1 = Tanh_Vector_Fp16(v1); + auto r2 = Tanh_Vector_Fp16(v2); + auto r3 = Tanh_Vector_Fp16(v3); + + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + MlasStoreFloat16x8(output + 16, r2); + MlasStoreFloat16x8(output + 24, r3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + auto r0 = Tanh_Vector_Fp16(v0); + auto r1 = Tanh_Vector_Fp16(v1); + + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStoreFloat16x8(output, r0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStoreFloat16x4(output, r0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 1); + } +} + +void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 Softcap) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + auto softcap8 = MlasBroadcastFloat16x8(Softcap.val); + auto softcap4 = MlasBroadcastFloat16x4(Softcap.val); + auto one8 = MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00); + auto one4 = MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00); + auto softcap_reciprocal8 = MlasDivideFloat16(one8, softcap8); + auto softcap_reciprocal4 = MlasDivideFloat16(one4, softcap4); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); + v1 = MlasMultiplyFloat16(v1, softcap_reciprocal8); + v2 = MlasMultiplyFloat16(v2, softcap_reciprocal8); + v3 = MlasMultiplyFloat16(v3, 
softcap_reciprocal8); + + v0 = Tanh_Vector_Fp16(v0); + v1 = Tanh_Vector_Fp16(v1); + v2 = Tanh_Vector_Fp16(v2); + v3 = Tanh_Vector_Fp16(v3); + + v0 = MlasMultiplyFloat16(v0, softcap8); + v1 = MlasMultiplyFloat16(v1, softcap8); + v2 = MlasMultiplyFloat16(v2, softcap8); + v3 = MlasMultiplyFloat16(v3, softcap8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + MlasStoreFloat16x8(output + 16, v2); + MlasStoreFloat16x8(output + 24, v3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); + v1 = MlasMultiplyFloat16(v1, softcap_reciprocal8); + + v0 = Tanh_Vector_Fp16(v0); + v1 = Tanh_Vector_Fp16(v1); + + v0 = MlasMultiplyFloat16(v0, softcap8); + v1 = MlasMultiplyFloat16(v1, softcap8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap8); + MlasStoreFloat16x8(output, v0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap4); + MlasStoreFloat16x4(output, v0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap4); + MlasStorePartialFloat16x4(output, v0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap4); + MlasStorePartialFloat16x4(output, v0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap4); + MlasStorePartialFloat16x4(output, v0, 1); + } +} + +MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) { + const auto* input = reinterpret_cast(Input); + auto max8 = MlasBroadcastFloat16x8((_mlas_fp16_)0xfbff); + auto max4 = MlasBroadcastFloat16x4((_mlas_fp16_)0xfbff); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + v0 = MlasMaximumFloat16(v0, v1); + v2 = MlasMaximumFloat16(v2, v3); + v0 = MlasMaximumFloat16(v0, v2); + max8 = MlasMaximumFloat16(max8, v0); + + input += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + v0 = MlasMaximumFloat16(v0, v1); + max8 = MlasMaximumFloat16(max8, v0); + + input += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + max8 = MlasMaximumFloat16(max8, v0); + + input += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + max4 = MlasMaximumFloat16(max4, v0); + + input += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + max4 = MlasMaximumFloat16(max4, v0); + } else if (N == 2) { + auto v0 = 
MlasLoadPartialFloat16x4(input, 2); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 2)); + max4 = MlasMaximumFloat16(max4, v0); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 2)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 1)); + max4 = MlasMaximumFloat16(max4, v0); + } + + auto t = MlasMaximumFloat16(vget_low_f16(max8), vget_high_f16(max8)); + t = MlasMaximumFloat16(t, max4); + _mlas_fp16_ result = MlasReduceMaximumFloat16(t); + + return MLAS_FP16::FromBits(result); +} + +void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 Sum) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + auto sum8 = MlasBroadcastFloat16x8(Sum.val); + auto sum4 = MlasBroadcastFloat16x4(Sum.val); + auto scale8 = MlasDivideFloat16(MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00), sum8); + auto scale4 = MlasDivideFloat16(MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00), sum4); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + v0 = MlasMultiplyFloat16(v0, scale8); + v1 = MlasMultiplyFloat16(v1, scale8); + v2 = MlasMultiplyFloat16(v2, scale8); + v3 = MlasMultiplyFloat16(v3, scale8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + MlasStoreFloat16x8(output + 16, v2); + MlasStoreFloat16x8(output + 24, v3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + v0 = MlasMultiplyFloat16(v0, scale8); + v1 = MlasMultiplyFloat16(v1, scale8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + v0 = MlasMultiplyFloat16(v0, scale8); + MlasStoreFloat16x8(output, v0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + v0 = MlasMultiplyFloat16(v0, scale4); + MlasStoreFloat16x4(output, v0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + v0 = MlasMultiplyFloat16(v0, scale4); + MlasStorePartialFloat16x4(output, v0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + v0 = MlasMultiplyFloat16(v0, scale4); + MlasStorePartialFloat16x4(output, v0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + v0 = MlasMultiplyFloat16(v0, scale4); + MlasStorePartialFloat16x4(output, v0, 1); + } +} + +void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum, const MLAS_FP16 LogSum) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + auto negative_maximum8 = MlasBroadcastFloat16x8(NegativeMaximum.val); + auto negative_maximum4 = MlasBroadcastFloat16x4(NegativeMaximum.val); + auto log_sum8 = MlasBroadcastFloat16x8(LogSum.val); + auto log_sum4 
= MlasBroadcastFloat16x4(LogSum.val); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + v0 = MlasAddFloat16(v0, negative_maximum8); + v1 = MlasAddFloat16(v1, negative_maximum8); + v2 = MlasAddFloat16(v2, negative_maximum8); + v3 = MlasAddFloat16(v3, negative_maximum8); + + v0 = MlasSubtractFloat16(v0, log_sum8); + v1 = MlasSubtractFloat16(v1, log_sum8); + v2 = MlasSubtractFloat16(v2, log_sum8); + v3 = MlasSubtractFloat16(v3, log_sum8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + MlasStoreFloat16x8(output + 16, v2); + MlasStoreFloat16x8(output + 24, v3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + v0 = MlasAddFloat16(v0, negative_maximum8); + v1 = MlasAddFloat16(v1, negative_maximum8); + + v0 = MlasSubtractFloat16(v0, log_sum8); + v1 = MlasSubtractFloat16(v1, log_sum8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + v0 = MlasAddFloat16(v0, negative_maximum8); + v0 = MlasSubtractFloat16(v0, log_sum8); + MlasStoreFloat16x8(output, v0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); + MlasStoreFloat16x4(output, v0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); + MlasStorePartialFloat16x4(output, v0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); + MlasStorePartialFloat16x4(output, v0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); + MlasStorePartialFloat16x4(output, v0, 1); + } +} + +} // namespace rope_neon diff --git a/src/lib/spool_kernel_neon.cpp b/src/lib/spool_kernel_neon.cpp new file mode 100644 index 0000000..5883625 --- /dev/null +++ b/src/lib/spool_kernel_neon.cpp @@ -0,0 +1,293 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + spool_kernel_neon.cpp + +Abstract: + + This module implements the single precision pooling kernels for ARM NEON. 
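// ---------------------------------------------------------------------------
// For orientation, a scalar sketch (illustrative only, simplified 1-D layout rather
// than the NCHWc-blocked one used below) of the two pooling flavours this module
// implements: max pooling treats out-of-bounds taps as negative infinity, while the
// "exclude pad" average divides by the number of in-bounds taps only; the
// include-pad variant divides by the full kernel size instead.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <limits>

static float MaxPool1D(const float* input, int input_len, int window_start, int kernel_size)
{
    float max_value = std::numeric_limits<float>::lowest();
    for (int k = 0; k < kernel_size; k++) {
        const int idx = window_start + k;
        if (idx >= 0 && idx < input_len) {
            max_value = std::max(max_value, input[idx]);
        }
    }
    return max_value;
}

static float AveragePool1DExcludePad(const float* input, int input_len, int window_start, int kernel_size)
{
    float sum = 0.0f;
    int valid = 0;
    for (int k = 0; k < kernel_size; k++) {
        const int idx = window_start + k;
        if (idx >= 0 && idx < input_len) {
            sum += input[idx];
            valid++;
        }
    }
    return valid > 0 ? sum / static_cast<float>(valid) : 0.0f;
}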
+ +--*/ + +#if defined(MLAS_USE_ARM_NEON_NCHWC) + +#include "mlasi.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +void + MLASCALL + MlasPoolMaximumFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(ActualKernelSize); + MLAS_UNREFERENCED_PARAMETER(InputStride); + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + const float MaxPaddingValue = std::numeric_limits::lowest(); + + const MLAS_FLOAT32X4 MaxPaddingVector = MlasBroadcastFloat32x4(MaxPaddingValue); + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + MLAS_FLOAT32X4 MaxVector0 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector1 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector2 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector3 = MaxPaddingVector; + + for (size_t kh = 0; kh < KernelHeight; kh++) { + const float* row_start = InputBase + kh * DilatedInputWidthElements; + const float* row_end = row_start + InputWidthElements; + + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_ptr = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if (input_ptr >= row_start && (input_ptr + BlockSize) <= row_end) { + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(input_ptr); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(input_ptr + 4); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(input_ptr + 8); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(input_ptr + 12); + + MaxVector0 = MlasMaximumFloat32x4(MaxVector0, InputVector0); + MaxVector1 = MlasMaximumFloat32x4(MaxVector1, InputVector1); + MaxVector2 = MlasMaximumFloat32x4(MaxVector2, InputVector2); + MaxVector3 = MlasMaximumFloat32x4(MaxVector3, InputVector3); + } else { + float values[BlockSize]; + for (size_t i = 0; i < BlockSize; i++) { + const float* element_ptr = input_ptr + i; + if (element_ptr >= row_start && element_ptr < row_end) { + values[i] = *element_ptr; + } else { + values[i] = MaxPaddingValue; + } + } + + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(&values[0]); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(&values[4]); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(&values[8]); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(&values[12]); + + MaxVector0 = MlasMaximumFloat32x4(MaxVector0, InputVector0); + MaxVector1 = MlasMaximumFloat32x4(MaxVector1, InputVector1); + MaxVector2 = MlasMaximumFloat32x4(MaxVector2, InputVector2); + MaxVector3 = MlasMaximumFloat32x4(MaxVector3, InputVector3); + } + } + } + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], MaxVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], MaxVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], MaxVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], MaxVector3); + } +} + +static void +MlasPoolAverageFloatKernelNeonImpl( + const float* Input, + float* Output, + size_t 
StrideWidth, + size_t DilationWidth, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + bool ExcludePad +) +{ + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + const MLAS_FLOAT32X4 ZeroVector = MlasZeroFloat32x4(); + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + MLAS_FLOAT32X4 SumVector0 = ZeroVector; + MLAS_FLOAT32X4 SumVector1 = ZeroVector; + MLAS_FLOAT32X4 SumVector2 = ZeroVector; + MLAS_FLOAT32X4 SumVector3 = ZeroVector; + + std::vector valid_count; + if (ExcludePad) { + valid_count.resize(BlockSize, 0); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + const float* row_start = InputBase + kh * DilatedInputWidthElements; + const float* row_end = row_start + InputWidthElements; + + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_ptr = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if (input_ptr >= row_start && (input_ptr + BlockSize) <= row_end) { + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(input_ptr); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(input_ptr + 4); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(input_ptr + 8); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(input_ptr + 12); + + SumVector0 = MlasAddFloat32x4(SumVector0, InputVector0); + SumVector1 = MlasAddFloat32x4(SumVector1, InputVector1); + SumVector2 = MlasAddFloat32x4(SumVector2, InputVector2); + SumVector3 = MlasAddFloat32x4(SumVector3, InputVector3); + + if (ExcludePad) { + for (size_t i = 0; i < BlockSize; i++) { + valid_count[i]++; + } + } + } else { + float values[BlockSize]; + for (size_t i = 0; i < BlockSize; i++) { + const float* element_ptr = input_ptr + i; + if (element_ptr >= row_start && element_ptr < row_end) { + values[i] = *element_ptr; + if (ExcludePad) { + valid_count[i]++; + } + } else { + values[i] = 0.0f; + } + } + + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(&values[0]); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(&values[4]); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(&values[8]); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(&values[12]); + + SumVector0 = MlasAddFloat32x4(SumVector0, InputVector0); + SumVector1 = MlasAddFloat32x4(SumVector1, InputVector1); + SumVector2 = MlasAddFloat32x4(SumVector2, InputVector2); + SumVector3 = MlasAddFloat32x4(SumVector3, InputVector3); + } + } + } + + if (ExcludePad) { + float results[BlockSize]; + + MlasStoreFloat32x4(&results[0], SumVector0); + MlasStoreFloat32x4(&results[4], SumVector1); + MlasStoreFloat32x4(&results[8], SumVector2); + MlasStoreFloat32x4(&results[12], SumVector3); + + for (size_t i = 0; i < BlockSize; i++) { + results[i] = results[i] / static_cast(valid_count[i]); + } + + MLAS_FLOAT32X4 ResultVector0 = MlasLoadFloat32x4(&results[0]); + MLAS_FLOAT32X4 ResultVector1 = MlasLoadFloat32x4(&results[4]); + MLAS_FLOAT32X4 ResultVector2 = MlasLoadFloat32x4(&results[8]); + MLAS_FLOAT32X4 ResultVector3 = MlasLoadFloat32x4(&results[12]); + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], 
ResultVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], ResultVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], ResultVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], ResultVector3); + } else { + const float KernelSize = static_cast(ActualKernelSize); + const MLAS_FLOAT32X4 KernelSizeVector = MlasBroadcastFloat32x4(KernelSize); + + MLAS_FLOAT32X4 ResultVector0 = MlasDivideFloat32x4(SumVector0, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector1 = MlasDivideFloat32x4(SumVector1, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector2 = MlasDivideFloat32x4(SumVector2, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector3 = MlasDivideFloat32x4(SumVector3, KernelSizeVector); + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], ResultVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], ResultVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], ResultVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], ResultVector3); + } + } +} + +void + MLASCALL + MlasPoolAverageExcludePadFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(InputStride); + + MlasPoolAverageFloatKernelNeonImpl( + Input, Output, StrideWidth, DilationWidth, ActualKernelSize, + KernelHeight, KernelWidth, InputBase, InputWidth, DilatedInputWidth, + OutputCountLeftPad, OutputCount, OutputCountRightPad, + true // ExcludePad = true + ); +} + +void + MLASCALL + MlasPoolAverageIncludePadFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(InputStride); + + MlasPoolAverageFloatKernelNeonImpl( + Input, Output, StrideWidth, DilationWidth, ActualKernelSize, + KernelHeight, KernelWidth, InputBase, InputWidth, DilatedInputWidth, + OutputCountLeftPad, OutputCount, OutputCountRightPad, + false // ExcludePad = false + ); +} + +#endif diff --git a/src/lib/sqnbitgemm.cpp b/src/lib/sqnbitgemm.cpp deleted file mode 100644 index b45f3a1..0000000 --- a/src/lib/sqnbitgemm.cpp +++ /dev/null @@ -1,772 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm.cpp - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, - as well as some SQNBitGemm-related query functions. ---*/ - -#include "sqnbitgemm.h" -#include "sqnbitgemm_q8_block.h" - -#include -#include - -namespace -{ - -enum SQNBitGemmVariant { - SQNBitGemmVariantInvalid = -1, - - // Valid variants - - SQNBitGemmVariant_BitWidth4_CompFp32 = 0, - SQNBitGemmVariant_BitWidth4_CompInt8, - - // End of valid variants - - // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. - // Its value is used as an array size. 
- SQNBitGemmVariantCount, -}; - -SQNBitGemmVariant -GetSQNBitGemmVariant( - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - if (BlkBitWidth == 4 && - (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == CompFp32 || - ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 - return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8) { - return SQNBitGemmVariant_BitWidth4_CompInt8; - } - } - - return SQNBitGemmVariantInvalid; -} - -} // namespace - -bool MLASCALL -MlasIsSQNBitGemmAvailable( - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return false; - } - - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); - - switch (Variant) { - case SQNBitGemmVariant_BitWidth4_CompFp32: { - return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && - Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; - } - case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 - return - (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || - (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); - } - default: { - return false; - } - } -} - -namespace -{ - -size_t -SQNBitGemmPerGemmWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return 0; - } - - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); - } - - return 0; -} - -size_t -SQNBitGemmPerGemmWorkspaceAlignment( - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return 1; - } - - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); - } - - return 1; -} - -size_t -SQNBitGemmPerGemmWorkspaceStride( - size_t M, - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); - const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); - return MlasDivRoundup(Size, Alignment) * Alignment; -} - -} // namespace - -size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BatchN, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (PerGemmWorkspaceStride == 0) { - return 0; - } - - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); - - const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; - - return WorkspaceSize + Alignment - 1; -} - -size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - const 
auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return 0; - } - - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { - return Dispatch->SQ4BitGemmPackQuantBDataSize( - N, K, BlkLen, ComputeType - ); - } - - return 0; -} - -struct PerGemmQuantAWorkspace { - PerGemmQuantAWorkspace(void* PerGemmWorkspace, size_t M, size_t BlockCountK, size_t BlkLen) - : PerGemmWorkspace_(PerGemmWorkspace), M_(M), BlockCountK_(BlockCountK), BlkLen_(BlkLen) - { - QuantData = (std::byte*)PerGemmWorkspace; - QuantScale = (float*)(QuantData + M * BlockCountK * BlkLen); - BlockSum = QuantScale + M * BlockCountK; - } - std::byte* QuantData; // NxBlockCountKxBlkLen - float* QuantScale; // NxBlockCountK - float* BlockSum; // NxBlockCountK - void* PerGemmWorkspace_; // memory for above data - size_t M_, BlockCountK_, BlkLen_; -}; - -void MLASCALL -MlasSQNBitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkBitWidth, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const void* QuantBData, - void* PackedQuantBDataAndOrBlkSumWorkspace, - const void* QuantBScale, - bool has_zp_input, - const void* QuantBZeroPoint, - MLAS_THREADPOOL* ThreadPool -) -{ - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; - if (Dispatch == nullptr) { - return; - } - - if (BlkBitWidth == 4) { - if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); - Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(QuantBScale), - has_zp_input, - static_cast(QuantBZeroPoint), - packed_quant_b, - ThreadPool - ); - } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. 
- //assert(QuantBScale == nullptr); - //assert(QuantBZeroPoint == nullptr); - Dispatch->SQ4BitGemmPackQuantBData( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(PackedQuantBDataAndOrBlkSumWorkspace), - ThreadPool - ); - return; - } - } -} - -namespace -{ - -MLAS_FORCEINLINE void -AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) -{ - for (size_t m = 0; m < CountM; m++) { - const float* bias = Bias; - float* sum = C; - for (size_t n = 0; n < CountN; n += 4) { - if (CountN - n < 4) { - for (size_t nn = n; nn < CountN; nn++) { - *sum += *bias; - sum++; - bias++; - } - break; - } - - MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); - acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); - MlasStoreFloat32x4(sum, acc_x); - bias += 4; - sum += 4; - } - C += ldc; - } -} - -typedef void(SQNBitGemmFn)( - size_t BlkLen, - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* PerGemmWorkspace, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); - -void -SQ4BitGemm_CompFp32( - const size_t BlkLen, - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - void* const PerGemmWorkspace, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) -{ - constexpr size_t BlkBitWidth = 4; - - MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); - - const size_t lda = DataParams->lda; - const size_t ldc = DataParams->ldc; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const float* A = DataParams->A + RangeStartM * lda; - - const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const std::byte* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; - - if (RangeCountM == 1) { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const float* a_row = A; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } - return; - } - - constexpr size_t StrideN = 32; - size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); - MlasThreadedBufAlloc(bufsize); - auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); - - // - // Step through each slice of matrix B along the N dimension. 
- // - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, StrideN); - - // - // Step through each slice of matrix A along the M dimension. - // - const float* a_row = A; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( - BlkLen, - dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks - ); - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true - ); -#else - auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); -#endif - - if (bias) { - AddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); - } - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, ldc - ); - } - - c_blk += ldc * RowsHandled; - a_row += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } - } -} - -void -SQ4BitGemm_CompInt8( - const size_t BlkLen, - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - void* const PerGemmWorkspace, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) -{ -#ifdef MLAS_TARGET_AMD64_IX86 - PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); - constexpr size_t BlkBitWidth = 4; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - - // quant A scale is embedded in QuantData if QuantScale is nullptr. - const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen)); - const size_t ldc = DataParams->ldc; - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; - const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; - - assert(RangeStartN % 4 == 0); - const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const std::byte* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; - const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; -#else - constexpr size_t BlkBitWidth = 4; - - const size_t k_blks = MlasDivRoundup(K, BlkLen); - - const size_t lda = k_blks * Q8BlkSize(BlkLen); - const size_t ldc = DataParams->ldc; - const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - - const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; - - const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; - const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; - const std::byte* QuantBZeroPoint = - (DataParams->QuantBZeroPoint == nullptr) - ? nullptr - : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; -#endif - - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const std::byte* a_row = QuantA; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, - RowsHandled, CountN, ldc - ); - } - - c_blk += RowsHandled * ldc; - a_row += RowsHandled * lda; - - RowsRemaining -= RowsHandled; - } - } -#ifdef MLAS_TARGET_AMD64_IX86 - else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) - { - const float* b_blk_sum = QuantBBlkSum + n * k_blks; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( - BlkLen, - QuantA, - QuantAScale, - b_col, - b_col_scale, - b_col_zp, - c_blk, - RangeCountM, - CountN, - K, - k_blks, - bias, - ldc, - ABlockSum, - b_blk_sum - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - } -#endif - } -} - -typedef void(InitializeWorkspaceFn)( - size_t M, - size_t N, - size_t K, - size_t BatchN, - size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* Workspace, - size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool -); - -void -InitializeWorkspace_CompInt8( - size_t M, - size_t N, - size_t K, - size_t BatchN, - size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* Workspace, - size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool -) -{ - MLAS_UNREFERENCED_PARAMETER(N); - - const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - - 
// TODO: try parallel on BatchN * M threads because BatchN is usually 1.
-    if (QuantizeARow) {
-        MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
-            const auto& data = DataParams[gemm_idx];
-
-            const float* ARowPtr = data.A;
-            std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
-            for (size_t m = 0; m < M; ++m) {
-                QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr);
-
-                ARowPtr += data.lda;
-                QuantARowPtr += QuantAStride;
-            }
-        });
-    } else {
-        MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
-            const auto& data = DataParams[gemm_idx];
-            const float* ARowPtr = data.A;
-
-            void* PerGemmWorkspace = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
-            PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen);
-            std::byte* QuantARowPtr = quant_a_data.QuantData;
-            float* QuantARowScalePtr = quant_a_data.QuantScale;
-            float* QuantARowBlkSum = quant_a_data.BlockSum;
-            for (size_t m = 0; m < M; ++m) {
-                QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum);
-                ARowPtr += data.lda;
-                QuantARowPtr += BlockCountK * BlkLen;
-                QuantARowScalePtr += BlockCountK;
-                QuantARowBlkSum += BlockCountK;
-            }
-        });
-    }
-}
-
-struct Operations {
-    InitializeWorkspaceFn* InitializeWorkspace = nullptr;
-    SQNBitGemmFn* SQNBitGemm = nullptr;
-};
-
-constexpr auto OperationMap = []() {
-    std::array<Operations, SQNBitGemmVariantCount> ops;
-
-    ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32;
-
-    ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8;
-    ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8;
-
-    return ops;
-}();
-} // namespace
-
-void MLASCALL
-MlasSQNBitGemmBatch(
-    const size_t M,
-    const size_t N,
-    const size_t K,
-    const size_t BatchN,
-    const size_t BlkBitWidth,
-    const size_t BlkLen,
-    MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
-    const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
-    void* Workspace,
-    MLAS_THREADPOOL* ThreadPool
-)
-{
-    const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType);
-    assert(Variant != SQNBitGemmVariantInvalid);
-
-    //
-    // Ensure `Workspace` has correct alignment.
-    //
-    if (Workspace != nullptr) {
-        const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType);
-        const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(Workspace);
-        Workspace = reinterpret_cast<void*>(
-            (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))
-        );
-    }
-
-    const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType);
-
-    if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace;
-        InitializeWorkspaceOperation != nullptr) {
-        InitializeWorkspaceOperation(
-            M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
-        );
-    }
-
-    const auto ComputeOperation = OperationMap[Variant].SQNBitGemm;
-
-    const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
-
-    if (ThreadPool == nullptr) {
-        for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
-            const auto* Data = &DataParams[gemm_i];
-            void* PerGemmWorkspace =
-                reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
-            if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) {
-                PackedQuantBDataStruct packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen);
-                const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData;
-                const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum;
-                const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale;
-                PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen);
-                ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N);
-            } else {
-                ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N);
-            }
-        }
-        return;
-    }
-
-    //
-    // Compute the number of target threads given the complexity of the SGEMM
-    // operation. Small requests should run using the single threaded path.
-    //
-
-    const double Complexity = double(M) * double(N) * double(K) * double(BatchN);
-
-    ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1;
-
-    ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8;
-
-    if (TargetThreadCount >= MaximumThreadCount) {
-        TargetThreadCount = MaximumThreadCount;
-    }
-
-    ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN;
-    if (ThreadsPerGemm < 1) {
-        ThreadsPerGemm = 1;
-    }
-
-    constexpr size_t StrideM = 128;
-
-    size_t nc = N;
-    if (ThreadsPerGemm > 1) {
-        // more than one thread per GEMM
-
-        const size_t BlockedM = MlasDivRoundup(M, StrideM);
-        const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm);
-        if (max_nc < nc) {
-            nc = std::min(
-                nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) *
-                        MLAS_QGEMM_STRIDEN_THREAD_ALIGN
-            );
-        }
-    }
-    const size_t StrideN = nc;
-
-    const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
-    const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
-    ThreadsPerGemm = ThreadCountM * ThreadCountN;
-
-    MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
-        const auto gemm_i = tid / ThreadsPerGemm;
-        const auto blk_i = tid % ThreadsPerGemm;
-        const auto* Data = &DataParams[gemm_i];
-
-        const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
-        const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
-
-        const size_t RangeStartM = ThreadIdM * StrideM;
-        const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM);
-
-        const size_t RangeStartN = ThreadIdN * StrideN;
-        const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
-
-        void* PerGemmWorkspace =
-            reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
-        if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) {
-            PackedQuantBDataStruct packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen);
-            const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData;
-            const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum;
-            const_cast<MLAS_SQNBIT_GEMM_DATA_PARAMS*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale;
-
-            PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen);
-            ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
-        } else {
-            ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
-        }
-    });
-}
diff --git a/src/lib/sqnbitgemm_kernel_avx2.cpp b/src/lib/sqnbitgemm_kernel_avx2.cpp
index abf8060..6416d25 100644
--- a/src/lib/sqnbitgemm_kernel_avx2.cpp
+++ b/src/lib/sqnbitgemm_kernel_avx2.cpp
@@ -18,9 +18,8 @@ Module Name:
 #include
 #include
 #include
-#include
-#include "sqnbitgemm.h"
+#include "qnbitgemm.h"
 #include "sqnbitgemm_kernel_avx_common.h"
 #include "sqnbitgemm_kernel_avx_common_int8.h"
 #include "sqnbitgemm_kernel_avx2_int8_blklen16.h"
@@ -29,6 +28,7 @@ Module Name:
 #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h"
 #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h"
+#include
 void
 MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size)
@@ -585,6 +585,89 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2(
     return CountM;
 }
+template <bool vnni>
+MLAS_FORCEINLINE
+size_t
+SQ8BitGemmKernel_BlkSum_CompInt8_avx2(
+    const size_t BlkLen,
+    const std::byte* QuantA,
+    const float* QuantAScale,
+    const std::byte* QuantBData,
+    const float* QuantBScale,
+    const std::byte* /*QuantBZeroPoint*/,
+    float* C,
+    size_t
CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + size_t SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( const size_t BlkLen, @@ -1307,12 +1390,39 @@ SQ4BitGemmPackQuantBDataAndBlkSum( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // TODO: always use SubBlkLen = 64 in SQNBIT_CompInt8 + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +static void +SQ8BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -1320,50 +1430,56 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - // TODO: always use SubBlkLen = 64 in CompInt8 size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (BlkLen == 32 && ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 64; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } // // Kernel dispatch structure definition. 
// -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; }(); -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 80d6780..aec5dc9 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -3,9 +3,16 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen16[40], 32) = { + 0x00000000, 0x00000000, 0x00000002, 0x00000002, 0x00000001, 0x00000001, 0x00000003, 0x00000003, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, + 
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000001, 0x00000001, 0x00000001, 0x00000001 +}; MLAS_FORCEINLINE __m256 load_and_broadcast_4_scale_2(const float* scale) @@ -152,6 +159,208 @@ accumulate_blklen16_r2c1blk4_avx2( scale_a0, scale_a1, scale_b, acc0, acc1); } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + // 22222222 22222222, 33333333 33333333 + const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + // 00 22, 11 33 + const __m256i scale_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16)); + // 0123, 0123 + __m256 scale_b_4_ps = _mm256_broadcast_ps((const __m128*)scale_b); + __m256 scale_a0_4_ps = _mm256_broadcast_ps((const __m128*)scale_a0); + __m256 scale_a0b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a0_4_ps); + __m256 scale_a0b_4_shuffle_ps = _mm256_permutevar_ps(scale_a0b_4_ps, scale_mask); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + // 0000, 1111 + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + // 2222, 3333 + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + // 0022, 1133 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 8)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 16)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + // 0000 0000, 1111 1111 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + // 00 00, 11 11 + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + // 22 22, 33 33 + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + // 00 22, 11 33 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, 
scale_a0b_4_shuffle_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r2c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + // 22222222 22222222, 33333333 33333333 + const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + // 00 22, 11 33 + const __m256i scale_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16)); + // 0123, 0123 + __m256 scale_b_4_ps = _mm256_broadcast_ps((const __m128*)scale_b); + __m256 scale_a0_4_ps = _mm256_broadcast_ps((const __m128*)scale_a0); + __m256 scale_a0b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a0_4_ps); + __m256 scale_a0b_4_shuffle_ps = _mm256_permutevar_ps(scale_a0b_4_ps, scale_mask); + __m256 scale_a1_4_ps = _mm256_broadcast_ps((const __m128*)scale_a1); + __m256 scale_a1b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a1_4_ps); + __m256 scale_a1b_4_shuffle_ps = _mm256_permutevar_ps(scale_a1b_4_ps, scale_mask); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + // 0000, 1111 + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + // 2222, 3333 + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + // 0022, 1133 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + + const __m256i dot10_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot11_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale_a1b_4_shuffle_ps, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 8)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 16)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + // 0000 0000, 1111 1111 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + // 00 00, 11 11 + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = 
_mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + // 22 22, 33 33 + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + // 00 22, 11 33 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale_a1b_4_shuffle_ps, acc1); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk1_avx2( + const __m128i& av00_16_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + __m256& acc0 +) +{ + const __m128i bv0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantBDataPtr)); + __m256 scale_a0b_1_ps = _mm256_set1_ps(scale_a0b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m128i dot00_4_epi32 = _mm_dpbusds_avx_epi32(_mm_setzero_si128(), bv0_16_epi8, av00_16_epi8); + const __m256i dot00_8_epi32 = _mm256_cvtepu32_epi64(dot00_4_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_1_ps, acc0); + } + else +#endif + { + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + const __m256i bv0_32_epi8 = _mm256_cvtepu8_epi16(bv0_16_epi8); + const __m256i av00_32_epi8 = _mm256_cvtepu8_epi16(av00_16_epi8); + const __m256i dot00_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot00_8_epi32 = _mm256_madd_epi16(one_mask, dot00_16_epi16); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_1_ps, acc0); + } +} + static MLAS_FORCEINLINE void accumulate_blklen16_r1c1blk4_avx2( const __m256i& av0_32_epi8, @@ -332,6 +541,118 @@ Q4Int8GemmR2xC4BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + constexpr size_t 
PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData4 = BlkDataSizeInBytes * PerAccuBlk4; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 3; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av_10_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc[0]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr, scale_a10 * scale_b0, acc[NCols4]); + + const float scale_b1 = *(QuantBScalePtr + 1); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a00 * scale_b1, acc[1]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a10 * scale_b1, acc[NCols4 + 1]); + + const float scale_b2 = *(QuantBScalePtr + 2); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a00 * scale_b2, acc[2]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + 2 * 
BlkDataSizeInBytes, scale_a10 * scale_b2, acc[NCols4 + 2]); + + const float scale_b3 = *(QuantBScalePtr + 3); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a00 * scale_b3, acc[3]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a10 * scale_b3, acc[NCols4 + 3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr+= NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( const std::byte* QuantA, const float* QuantAScale, @@ -437,6 +758,108 @@ void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + 32; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + 32; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + 
QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantABlk0)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantABlk0 + lda)); + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc0); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_a10 * scale_b0, acc1); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen16Avx2( const std::byte* QuantA, @@ -549,6 +972,106 @@ Q4Int8GemmR1xC4BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBData4 = BlkDataSizeInBytes * PerAccuBlk4; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, 
QuantBDataPtr + 3 * StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc[0]); + + const float scale_b1 = *(QuantBScalePtr + 1); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a00 * scale_b1, acc[1]); + + const float scale_b2 = *(QuantBScalePtr + 2); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a00 * scale_b2, acc[2]); + + const float scale_b3 = *(QuantBScalePtr + 3); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a00 * scale_b3, acc[3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen16Avx2( const std::byte* QuantA, @@ -634,6 +1157,90 @@ Q4Int8GemmR1xC1BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, 
QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_a0b = scale_a00 * (*QuantBScalePtr); + accumulate_q8_blklen16_r1c1blk1_avx2(av_16_epi8, QuantBDataPtr, scale_a0b, acc0); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen16Avx2( @@ -725,3 +1332,95 @@ MLAS_FORCEINLINE return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index af6f520..a745dd9 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -3,9 +3,14 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen32[24], 32) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001 +}; MLAS_FORCEINLINE void accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, @@ -115,6 +120,156 @@ accumulate_blklen32_r2c1blk2_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv1_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 01 01 01 01 + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_a0b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a0_2_ps); + __m256 scale0_8_ps = _mm256_shuffle_ps(scale_a0b_2_ps, scale_a0b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11 00 11 + const __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = 
_mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11, 00 11 + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv1_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 01 01 01 01 + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_a0b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a0_2_ps); + __m256 scale0_8_ps = _mm256_shuffle_ps(scale_a0b_2_ps, scale_a0b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_a1b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a1_2_ps); + __m256 scale1_8_ps = _mm256_shuffle_ps(scale_a1b_2_ps, scale_a1b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11 00 11 + const __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + + const __m256i dot10_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot11_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); // 00 11 00 11 + const __m256 sum1_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_ps, scale1_8_ps, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i 
dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11, 00 11 + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); // 00 11, 00 11 + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale1_8_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_avx2( @@ -196,6 +351,100 @@ accumulate_blklen32_r2c1blk1_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + __m256& acc0 +) +{ + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + accumulate_1blk_dot_vnni(av00_32_epi8, bv0_32_epi8, combined_scale00, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + __m256 dot00_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(dot00_8_ps, _mm256_set1_ps(combined_scale00), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk1_avx2( + 
const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + float combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + accumulate_1blk_dot_vnni(av00_32_epi8, bv0_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_vnni(av10_32_epi8, bv0_32_epi8, combined_scale10, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + __m256 dot00_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(dot00_8_ps, _mm256_set1_ps(combined_scale00), acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + __m256 dot10_8_ps = _mm256_cvtepi32_ps(dot10_8_epi32); + acc1 = _mm256_fmadd_ps(dot10_8_ps, _mm256_set1_ps(combined_scale10), acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_avx2( @@ -367,6 +616,116 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData2 = PerAccuBlk2 * BlkDataSizeInBytes; + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = 
QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + // load A: + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + + float scale_a00 = *QuantAScalePtr; + float scale_a10 = *(QuantAScalePtr + BlockCountK); + + float scale_00 = scale_a00 * (QuantBScalePtr)[0], scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + + float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0], scale_11 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, scale_11, acc[1], acc[NCols4 + 1]); + + float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0], scale_12 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, scale_12, acc[2], acc[NCols4 + 2]); + + float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0], scale_13 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, scale_13, acc[3], acc[NCols4 + 3]); + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // 
move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, @@ -460,6 +819,95 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmXx4BlkLen32Avx2( @@ -589,6 +1037,100 @@ Q4Int8GemmXx4BlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData2 = PerAccuBlk2 * BlkDataSizeInBytes; + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, acc[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, acc[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, acc[3]); + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = 
_mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmXxXBlkLen32Avx2( @@ -672,6 +1214,81 @@ Q4Int8GemmXxXBlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes16; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -765,6 +1382,98 @@ MLAS_FORCEINLINE return CountM; } +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} + // this function is to explore larger NCols. With Avx2 it does not improve performance. // Leave it here until the same is implemented in avx512. 
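// Note on the q8 (8-bit B) accumulate helpers defined above in this header: _mm256_maddubs_epi16
// multiplies the unsigned B bytes by the signed A bytes and saturates each adjacent-pair sum to
// int16, which can overflow when both operands use the full 8-bit range. The non-VNNI paths
// therefore mask B into its even and odd byte lanes with MasksAvx2BlkLen32, run maddubs on each
// half so every pair contains one zero product, widen the int16 pairs to int32 with
// _mm256_madd_epi16 against the 0x0001 mask, and add the two halves before the float fmadd with
// the combined scales. The VNNI path (_mm256_dpbusds_avx_epi32) accumulates u8 x s8 dot products
// directly into int32 and needs no such split.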
template accumulator> @@ -951,7 +1660,24 @@ MlasQ4Int8TileGemmKernelBlkLen32Avx2( if constexpr (NCols4 == 8) { __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + // Clang is not happy with the code here, even if constexpr `NCols4 == 8` is always false in this context: + // + // In file included from .../onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp:26: + // .../onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h:1663:49: error: array index 4 is past the end of the array (that has type '__m256[4]') [-Werror,-Warray-bounds] + // 1663 | __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + // | ^ ~ + // .../onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h:1531:13: note: array 'acc' declared here + // 1531 | __m256 acc[NCols4]; + // | ^ +#if defined(__clang__) && defined(HAS_ARRAY_BOUNDS) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Warray-bounds" +#endif __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); +#if defined(__clang__) && defined(HAS_ARRAY_BOUNDS) +#pragma clang diagnostic pop +#endif if (BiasPtr != nullptr) { acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 174ebc5..2058374 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -3,9 +3,15 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen64[24], 32) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001 +}; + template static MLAS_FORCEINLINE void accumulate_blklen64_r2c1blk1_avx2( @@ -76,6 +82,141 @@ accumulate_blklen64_r2c1blk1_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& bv0_32_epi8, + const __m256i& bv1_32_epi8, + float scale_a0b, + __m256& acc0 +) +{ + __m256 scale_8_ps = _mm256_set1_ps(scale_a0b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = 
_mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_add_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_8_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const __m256i& bv0_32_epi8, + const __m256i& bv1_32_epi8, + float scale_a0b, + float scale_a1b, + __m256& acc0, + __m256& acc1 +) +{ + __m256 scale0_8_ps = _mm256_set1_ps(scale_a0b); + __m256 scale1_8_ps = _mm256_set1_ps(scale_a1b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + __m256i sum0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum0_8_epi32 = _mm256_dpbusds_avx_epi32(sum0_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + + __m256i sum1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + sum1_8_epi32 = _mm256_dpbusds_avx_epi32(sum1_8_epi32, bv1_32_epi8, av11_32_epi8); + __m256 sum1_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_ps, scale1_8_ps, acc1); + } + else + #endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = 
_mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_add_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_add_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale1_8_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx2( @@ -117,7 +258,7 @@ accumulate_blklen64_r1c1blk1_avx2( __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -212,6 +353,138 @@ Q4Int8GemmR2xC4BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a1 = *(QuantAScalePtr + BlockCountK); + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a1b0 = (*QuantBScalePtr) * 
scale_a1; + const float scale_a0b1 = (*(QuantBScalePtr + 1)) * scale_a0; + const float scale_a1b1 = (*(QuantBScalePtr + 1)) * scale_a1; + const float scale_a0b2 = (*(QuantBScalePtr + 2)) * scale_a0; + const float scale_a1b2 = (*(QuantBScalePtr + 2)) * scale_a1; + const float scale_a0b3 = (*(QuantBScalePtr + 3)) * scale_a0; + const float scale_a1b3 = (*(QuantBScalePtr + 3)) * scale_a1; + + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256i bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + __m256i bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + __m256i bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + __m256i bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + __m256i bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + __m256i bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + } + + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, 
av_10_epi8, av_11_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx2( @@ -292,6 +565,109 @@ Q4Int8GemmR2xC1BlkLen64Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a1 = *(QuantAScalePtr + BlockCountK); + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a1b0 = (*QuantBScalePtr) * scale_a1; + + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + + for (size_t kk = 0; kk < 
PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + } + + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx2( @@ -371,6 +747,120 @@ Q4Int8GemmR1xC4BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a0b1 = (*(QuantBScalePtr + 1)) * scale_a0; + const float scale_a0b2 = (*(QuantBScalePtr + 2)) * scale_a0; + const float scale_a0b3 = (*(QuantBScalePtr + 3)) * scale_a0; + + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = 
_mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256i bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + __m256i bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + __m256i bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + __m256i bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + __m256i bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + __m256i bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, acc[3]); + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + } + + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, acc[3]); + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx2( @@ -447,6 +937,97 @@ Q4Int8GemmR1xC1BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc0); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + } + + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc0); + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx2( @@ -539,3 +1120,96 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512.cpp b/src/lib/sqnbitgemm_kernel_avx512.cpp index 127279a..bfb2959 100644 --- a/src/lib/sqnbitgemm_kernel_avx512.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512.cpp @@ -18,8 +18,9 @@ Module Name: #include #include #include -#include -#include "sqnbitgemm.h" +#include + +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx512_int8_blklen16.h" @@ -28,7 +29,7 @@ Module Name: #include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. // #include "sqnbitgemm_kernel_avx_common_fp32.h" @@ -151,7 +152,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( } // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. 
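// The SQ8BitGemmKernel_BlkSum_CompInt8_avx512 entry point added below selects the
// MlasQ8Int8GemmKernelBlkLen{16,32,64,128}Avx512 kernel by block length, then loops
// GemmFloatKernel over ABlockSum rows against QuantBBlkSum to accumulate the per-block
// sums into C.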
// MLAS_FORCEINLINE @@ -247,6 +248,100 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512( return CountM; } +MLAS_FORCEINLINE +size_t +SQ8BitGemmKernel_BlkSum_CompInt8_avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ8Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL QuantizeARow_CompInt8_avx512( size_t BlkLen, @@ -332,12 +427,38 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +static void +SQ8BitGemmPackQuantBDataAndBlkSum512( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -346,26 +467,30 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8.h b/src/lib/sqnbitgemm_kernel_avx512_int8.h index 7d9dc36..8f1ea66 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" @@ -81,7 +81,7 @@ accumulate_blklen32_r2c1blk2_avx2( _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) ); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -143,7 +143,7 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. 
////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); @@ -184,7 +184,7 @@ accumulate_blklen32_r2c1blk1_avx2( const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - + const int8_t zp = get_zp(true, QuantBZeroPointPtr); const __m256i bzp = _mm256_set1_epi8(zp); bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); @@ -435,7 +435,7 @@ Q4Int8Gemm2x4BlkLen32Avx2( } } -template +template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, const std::byte* QuantBData, @@ -877,7 +877,7 @@ MLAS_FORCEINLINE QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, C + multipleRows * ldc + multipleCols, remainingRows, - remainingCols, + remainingCols, BlockCountK, Bias ? Bias + multipleCols : nullptr, lda, diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index 60a8873..6e8cebe 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" @@ -35,6 +35,13 @@ // bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); //} +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen128[32], 64) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE void dot_accumulate_1blk( const __m512i& bv0_64_epi8, @@ -139,6 +146,117 @@ accumulate_blklen128_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen128_r1c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); + + if constexpr (vnni) { + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, scale_a0b, acc0); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i 
dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(dot00_16_epi32, dot01_16_epi32); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_set1_ps(scale_a0b), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen128_r2c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + float scale_a1b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); + + if constexpr (vnni) { + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, scale_a0b, acc0); + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, scale_a1b, acc1); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(dot00_16_epi32, dot01_16_epi32); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_set1_ps(scale_a0b), acc0); + + // row 1 + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i 
dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(dot10_16_epi32, dot11_16_epi32); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_set1_ps(scale_a1b), acc1); + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen128Avx512( @@ -251,6 +369,110 @@ Q4Int8GemmR2xC4BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a0b1 = (*QuantAScalePtr) * (*(QuantBScalePtr + 1)); + const float scale_a0b2 = (*QuantAScalePtr) * (*(QuantBScalePtr + 2)); + const float scale_a0b3 = (*QuantAScalePtr) * (*(QuantBScalePtr + 3)); + const float scale_a1b0 = (*(QuantAScalePtr + BlockCountK)) * (*QuantBScalePtr); + const float scale_a1b1 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 1)); + const float scale_a1b2 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 2)); + const float scale_a1b3 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 3)); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + 
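// One call per B column: column j reads QuantBDataPtr + j * SubblkDataSizeInBytes and accumulates into acc[j] (row 0) and acc[NCols4 + j] (row 1) using the matching scale_a0b{0..3} / scale_a1b{0..3} scales.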
accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + *(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen128Avx512( @@ -332,6 +554,89 @@ Q4Int8GemmR2xC1BlkLen128Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a1b0 = (*(QuantAScalePtr + BlockCountK)) * (*QuantBScalePtr); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen128Avx512( @@ -411,6 +716,90 @@ Q4Int8GemmR1xC4BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a0b1 = (*QuantAScalePtr) * (*(QuantBScalePtr + 1)); + const float scale_a0b2 = (*QuantAScalePtr) * (*(QuantBScalePtr + 2)); + const float scale_a0b3 = (*QuantAScalePtr) * (*(QuantBScalePtr + 3)); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, scale_a0b0, acc[0]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, acc[1]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, acc[2]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr +=NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen128Avx512( @@ -487,6 +876,82 @@ Q4Int8GemmR1xC1BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, scale_a0b0, acc0); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen128Avx512( @@ -579,3 +1044,97 @@ MlasQ4Int8GemmKernelBlkLen128Avx512( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index bb14bab..b720c45 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -3,13 +3,20 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" - +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen16[48], 64) = { + 0x00000000, 0x00000000, 0x00000004, 0x00000004, 0x00000001, 0x00000001, 0x00000005, 0x00000005, + 0x00000002, 0x00000002, 0x00000006, 0x00000006, 0x00000003, 0x00000003, 0x00000007, 0x00000007, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen16(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) @@ -120,6 +127,131 @@ accumulate_blklen16_r2c1blk4_avx512( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r2c1blk8_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); // 0044115522663377 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 0123456700000000 + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 0000111122223333 + const __m512i 
dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 4444555566667777 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a1b_ps, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); // 0000111122223333 + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); // 4444555566667777 + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi64(dot10_16_epi32, dot11_16_epi32); // 0044115522663377 + const __m512i t12 = _mm512_unpackhi_epi64(dot10_16_epi32, dot11_16_epi32); // 0044115522663377 + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk8_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); 
// 0044115522663377 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 0123456700000000 + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 0000111122223333 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 4444555566667777 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); +} + static MLAS_FORCEINLINE void accumulate_blklen16_r1c1blk8_avx512vnni( const __m512i& av0_64_epi8, @@ -214,6 +346,36 @@ accumulate_blklen16_r2c1blk4_avx512vnni( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk8_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); // 0044115522663377 + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 01234567 00000000 + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen16Avx512( @@ -399,6 +561,152 @@ 
Q4Int8GemmR2xC4BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData8 = BlkDataSizeInBytes * PerAccuBlk8; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[NCols4]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk8, acc[NCols4 + 1]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk8, acc[NCols4 + 2]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk8, acc[NCols4 + 3]); + } else { + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + 
accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk8, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk8, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk8, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += StrideQuantBData8 * NCols4; + QuantBScalePtr += PerAccuBlk8 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + // In A, the bytes beyond K has set to 0. + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_a0b0 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_a1b0 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_a0b0, acc2[0]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_a1b0, acc2[NCols4]); + + const float scale_a0b1 = scale_a00 * (QuantBScalePtr + 1)[0]; + const float scale_a1b1 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a0b1, acc2[1]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a1b1, acc2[NCols4 + 1]); + + const float scale_a0b2 = scale_a00 * (QuantBScalePtr + 2)[0]; + const float scale_a1b2 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a0b2, acc2[2]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a1b2, acc2[NCols4 + 2]); + + const float scale_a0b3 = scale_a00 * (QuantBScalePtr + 3)[0]; + const float scale_a1b3 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a0b3, acc2[3]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a1b3, acc2[NCols4 + 3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr 
+= NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2C1BlkLen16Avx512( @@ -509,6 +817,108 @@ Q4Int8GemmR2C1BlkLen16Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc1); + } else { + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_00, acc20); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_10, acc21); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + 
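// The single per-column bias value is added to both output rows of this column.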
*SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen16Avx512( @@ -628,6 +1038,118 @@ Q4Int8GemmR1xC4BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData8 = PerAccuBlk8 * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + } else { + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += StrideQuantBData8 * NCols4; + QuantBScalePtr += PerAccuBlk8 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), 
h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, acc2[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, acc2[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, acc2[3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen16Avx512( @@ -719,6 +1241,94 @@ Q4Int8GemmR1xC1BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2 = 
h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -810,3 +1420,94 @@ MlasQ4Int8GemmKernelBlkLen16Avx512( return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index e9df6b9..f630883 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -3,11 +3,20 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen32[48], 64) = { + 0x00000000, 0x00000000, 0x00000002, 0x00000002, 0x00000000, 0x00000000, 0x00000002, 0x00000002, + 0x00000001, 0x00000001, 0x00000003, 0x00000003, 0x00000001, 0x00000001, 0x00000003, 0x00000003, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) { @@ -115,6 +124,139 @@ accumulate_blklen32_r2c1blk4_avx512( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 00220022 11331133 + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 00000000 11111111 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 22222222 33333333 + const __m512i dot01_high_16_epi32 = 
_mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 00220022 11331133 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); +} + +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 00220022 11331133 + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 00000000 11111111 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 22222222 33333333 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 00220022 11331133 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = 
_mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a1b_ps, 0); + + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 00220022 11331133 + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); // 00000000 11111111 + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); // 22222222 33333333 + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi64(dot10_16_epi32, dot11_16_epi32); // 00220022 11331133 + const __m512i t12 = _mm512_unpackhi_epi64(dot10_16_epi32, dot11_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); +} + static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk4_avx512vnni( const __m512i& av0_64_epi8, @@ -203,6 +345,38 @@ accumulate_blklen32_r2c1blk4_avx512vnni( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); +} + MLAS_FORCEINLINE void accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) { @@ -256,6 +430,44 @@ accumulate_blklen32_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk1_avx512( + const __m256i& av00_32_epi8, + const 
std::byte* QuantBDataPtr, + float combined_scale00, + __m256& acc0 +) +{ + if constexpr (vnni) { + const __m256i bv_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { + accumulate_q8_blklen32_r1c1blk1_avx2(av00_32_epi8, QuantBDataPtr, combined_scale00, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk1_avx512( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + float combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + if constexpr (vnni) { + const __m256i bv_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_avx512vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { + accumulate_q8_blklen32_r2c1blk1_avx2(av00_32_epi8, av10_32_epi8, QuantBDataPtr, combined_scale00, combined_scale10, acc0, acc1); + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen32Avx512( @@ -437,6 +649,142 @@ Q4Int8GemmR2xC4BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = PerAccuBlk4 * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[NCols4]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + 
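+                    // Accumulator layout for the 2x4 tile: acc[c] collects row 0 of column c and
+                    // acc[NCols4 + c] collects row 1, so each pair of calls here handles one column.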
accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[NCols4 + 1]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[NCols4 + 2]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[NCols4 + 3]); + } else { + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + // load A + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0], scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0], scale_11 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, scale_11, acc2[1], acc2[NCols4 + 1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0], scale_12 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, scale_12, acc2[2], acc2[NCols4 + 2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0], scale_13 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, scale_13, acc2[3], acc2[NCols4 
+ 3]); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2C1BlkLen32Avx512( @@ -548,8 +896,8 @@ Q4Int8GemmR2C1BlkLen32Avx512( } template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen32Avx512( +void MLAS_FORCEINLINE +Q8Int8GemmR2C1BlkLen32Avx512( const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, @@ -559,69 +907,170 @@ Q4Int8GemmR1xC4BlkLen32Avx512( size_t CountN, size_t BlockCountK, const float* Bias, - size_t ldc -) + size_t ldc) { constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); constexpr size_t PerAccuBlk4 = 4; const size_t lda = BlockCountK * BlkLen32; - //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); - for (size_t m = 0; m < CountM; m++) { + for (size_t m = 0; m < CountM; m += NRows2) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; + float* SumPtr = C + m * ldc; - for (size_t n = 0; n < CountN; n += NCols4) { + for (size_t n = 0; n < CountN; n++) { const std::byte* QuantAPtr = QuantA + m * lda; const float* QuantAScalePtr = QuantAScale + m * BlockCountK; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; - __m512 acc[NCols4] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); if constexpr (vnni) { - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, 
QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc1); } else { - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); } + // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk4; QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; - QuantBScalePtr += PerAccuBlk4 * NCols4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; } - __m256 acc2[NCols4] = { - h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) - }; - - while (k_blks_remaining-- > 0) { + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + while (k_blks_remaining-- > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; const __m256i av_00_epi8 = _mm256_loadu_si256((const 
__m256i*)QuantABlk0); @@ -667,6 +1116,114 @@ Q4Int8GemmR1xC4BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + 
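+                // Tail blocks (BlockCountK % PerAccuBlk4): accumulate one 32-element block per
+                // iteration into each of the four column accumulators, folding the per-block A and
+                // B scales into a single combined scale before the fused multiply-add.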
accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_01, acc2[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_02, acc2[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_03, acc2[3]); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen32Avx512( @@ -759,6 +1316,94 @@ Q4Int8GemmR1xC1BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc2 = h_add_512(acc0); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + + const float scale_a00 = *QuantAScalePtr; + const float scale_00 = scale_a00 * 
(QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -850,3 +1495,93 @@ MlasQ4Int8GemmKernelBlkLen32Avx512( return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2C1BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 2a65ac4..68bf1da 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -3,9 +3,16 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen64[32], 64) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE __m256 h_add_512(__m512 a) { @@ -125,7 +132,7 @@ dot_accumulate_2blkvnni( __m512i t1_16_epi32 = _mm512_unpacklo_epi32(dot0_16_epi32, dot1_16_epi32); __m512i t2_16_epi32 = _mm512_unpackhi_epi32(dot0_16_epi32, dot1_16_epi32); - __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 0 1 1 0 0 1 1... + __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 1 0 1 0 1 0 1... __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); @@ -182,6 +189,146 @@ accumulate_blklen64_r2c1blk2_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni(av00_64_epi8, av01_64_epi8, scale_a0, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, acc0); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + const __m512 scale_a0_16_ps = _mm512_broadcast_f32x8(scale_a0_ps); + const __m512 scale_a0b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a0_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 
= _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi32(dot00_16_epi32, dot01_16_epi32); // 01010101 01010101 + const __m512i t02 = _mm512_unpackhi_epi32(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni(av00_64_epi8, av01_64_epi8, scale_a0, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, acc0); + dot_accumulate_2blkvnni(av10_64_epi8, av11_64_epi8, scale_a1, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, acc1); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + const __m512 scale_a0_16_ps = _mm512_broadcast_f32x8(scale_a0_ps); + const __m512 scale_a0b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a0_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi32(dot00_16_epi32, dot01_16_epi32); // 01010101 01010101 + const __m512i t02 = 
_mm512_unpackhi_epi32(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m256 scale_a1_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + const __m512 scale_a1_16_ps = _mm512_broadcast_f32x8(scale_a1_ps); + const __m512 scale_a1b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a1_16_ps); + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi32(dot10_16_epi32, dot11_16_epi32); + const __m512i t12 = _mm512_unpackhi_epi32(dot10_16_epi32, dot11_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk2_avx512( @@ -283,6 +430,112 @@ accumulate_blklen64_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk1_avx512( + const __m512i& av0_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } else { + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + __m512i bv_low_64_epi8 = _mm512_and_si512(bv_64_epi8, low_mask); + __m512i bv_high_64_epi8 = _mm512_and_si512(bv_64_epi8, high_mask); + + // row 0 + __m512i dot0_low_32_epi16 = _mm512_maddubs_epi16(bv_low_64_epi8, av0_64_epi8); + __m512i dot0_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av0_64_epi8); + __m512i dot0_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_low_32_epi16); + __m512i dot0_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_high_32_epi16); + __m512i dot0_16_epi32 = _mm512_add_epi32(dot0_low_16_epi32, dot0_high_16_epi32); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); 
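+        // Note on the non-VNNI path above: the 0x00ff/0xff00 byte masks split B into even and odd
+        // byte lanes, so each _mm512_maddubs_epi16 sees only one nonzero u8*s8 product per 16-bit
+        // lane and its saturating add cannot overflow; the two halves are then widened with
+        // _mm512_madd_epi16(ones, ...) and summed as 32-bit integers.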
+ + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk1_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av1_64_epi8); + __m512 sum1_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } else { + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + __m512i bv_low_64_epi8 = _mm512_and_si512(bv_64_epi8, low_mask); + __m512i bv_high_64_epi8 = _mm512_and_si512(bv_64_epi8, high_mask); + + // row 0 + __m512i dot0_low_32_epi16 = _mm512_maddubs_epi16(bv_low_64_epi8, av0_64_epi8); + __m512i dot0_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av0_64_epi8); + __m512i dot0_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_low_32_epi16); + __m512i dot0_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_high_32_epi16); + __m512i dot0_16_epi32 = _mm512_add_epi32(dot0_low_16_epi32, dot0_high_16_epi32); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + + // row 1 + __m512i dot1_low_32_epi16 = _mm512_maddubs_epi16(bv_low_64_epi8, av1_64_epi8); + __m512i dot1_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av1_64_epi8); + __m512i dot1_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot1_low_32_epi16); + __m512i dot1_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot1_high_32_epi16); + __m512i dot1_16_epi32 = _mm512_add_epi32(dot1_low_16_epi32, dot1_high_16_epi32); + __m512 sum1_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx512( @@ -448,6 +701,106 @@ Q4Int8GemmR2xC4BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen64Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + 
size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen64); + + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + 
accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx512( @@ -540,6 +893,95 @@ Q4Int8GemmR2xC1BlkLen64Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr, acc0, acc1); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx512( @@ -633,6 +1075,94 @@ Q4Int8GemmR1xC4BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + 
accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx512( @@ -718,6 +1248,88 @@ Q4Int8GemmR1xC1BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_q8_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + + accumulate_q8_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx512( @@ -838,3 +1450,96 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen64Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512vnni.cpp b/src/lib/sqnbitgemm_kernel_avx512vnni.cpp index 6a5c011..e172308 100644 --- a/src/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" @@ -299,6 +299,100 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni( return CountM; } +MLAS_FORCEINLINE +size_t +SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ8Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL QuantizeARow_CompInt8_avx512( size_t BlkLen, @@ -314,12 +408,38 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +static void +SQ8BitGemmPackQuantBDataAndBlkSum512vnni( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -328,29 +448,33 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } // // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512vnni; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/src/lib/sqnbitgemm_kernel_avx_common.h b/src/lib/sqnbitgemm_kernel_avx_common.h index 177f551..36c15cd 100644 --- a/src/lib/sqnbitgemm_kernel_avx_common.h +++ b/src/lib/sqnbitgemm_kernel_avx_common.h @@ -1,28 +1,29 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" // // Quantized B data packing function implementation. 
// +template static size_t -SQ4BitGemmPackQuantBDataSize( +QNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + bool /* HasZeroPoint */, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - constexpr size_t BlkBitWidth = 4; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); - // _mm256_load_si256 requires alignment on a 32-byte boundary - constexpr size_t PackedQuantBDataAlignment = 32; + // avx512 requires alignment on a 64-byte boundary + constexpr size_t PackedQuantBDataAlignment = 64; PackedQuantBDataSize += PackedQuantBDataAlignment - 1; constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); BlkSumSize += BlkSumAlignment - 1; @@ -39,7 +40,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -246,6 +247,60 @@ PackQuantB( ); } +static void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen) +{ + constexpr size_t BlkBitWidth = 8; + const size_t StrideN = BlockCountK * BlkLen; + const size_t BlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubBlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); + const size_t SubBlkCountK = MlasDivRoundup(StrideN, SubBlkLen); + const size_t RemainderBlockCountK = BlockCountK % (SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // SubBlkLen rows x 4 columns pack together, then remainder BlkLen x 4 columns if SubBlkLen > BlkLen. + // remainder columns keep the original order. 
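+    //
+    // Illustrative example (hypothetical sizes, for the full-sub-block path below): with
+    // BlkLen = 32, SubBlkLen = 64 and BlockCountK = 4, each column holds StrideN = 128 bytes,
+    // SubBlkSize = 64 bytes and SubBlkCountK = 2. Column c = 5 (c_4 = 4, c_res = 1), sub-block
+    // r_subblk = 1 is read from byte offset c * StrideN + r_subblk * SubBlkLen = 704 and written
+    // to byte offset c_4 * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize = 832.
+    //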
+ // SubBlkLen >= 16 and is multiple of 16 + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / SubBlkCountK; + const size_t c_4 = c & (~3), c_res = c & 3; + const size_t r_subblk = tid % SubBlkCountK; + + const std::byte* src = QuantBDataBegin + c * StrideN + r_subblk * SubBlkLen; + + if (c_4 + 4 <= N) { // full 4 cols + if (RemainderBlockCountK && r_subblk == SubBlkCountK - 1) { // remainder blocks + std::byte* dest = + PackedQuantBDataBegin + c_4 * StrideN + r_subblk * SubBlkSize * 4 + c_res * BlkSize; + for (size_t i = 0; i < RemainderBlockCountK; i++) { + std::copy(src, src + BlkSize, dest); + src += BlkSize; + dest += BlkSize * 4; + } + } else { // full subblock + std::byte* dest = + PackedQuantBDataBegin + c_4 * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize; + std::copy(src, src + SubBlkSize, dest); + } + } else { // remainder cols + std::byte* dest = + PackedQuantBDataBegin + c * StrideN + r_subblk * SubBlkSize; + std::copy(src, src + std::min(SubBlkSize, StrideN - r_subblk * SubBlkSize), dest); + } + } + ); +} + //#include static void @@ -294,6 +349,61 @@ ComputePackBlkSum( ); } +static void +Q8ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK) +{ + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n_4 = n & (~3), n_res = n & 3; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset; + zp = (uint8_t)(*QuantBZP); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + + // re-arrange scale to the same order as packed data + if (n_4 + 4 > N) { + *(QuantBScaleBegin + n * BlockCountK + k_blk) = QuantBScale; + } else if (BlkLen >= SubBlkLen) { + *(QuantBScaleBegin + n_4 * BlockCountK + k_blk * 4 + n_res) = QuantBScale; + } else { + size_t blks_per_sub = SubBlkLen / BlkLen; + size_t remainder_blk = BlockCountK % blks_per_sub; + size_t sub_blk_count_k = MlasDivRoundup(BlockCountK, blks_per_sub); + size_t k_subblk = k_blk / blks_per_sub; + size_t k_blk_res = k_blk % blks_per_sub; + size_t dest_offset; + + if (remainder_blk && k_subblk == sub_blk_count_k - 1) { // remainder blocks + dest_offset = n_4 * BlockCountK + k_blk * 4 + n_res; + } else { // full subblock + dest_offset = n_4 * BlockCountK + k_subblk * blks_per_sub * 4 + n_res * blks_per_sub + k_blk_res; + } + + *(QuantBScaleBegin + dest_offset) = QuantBScale; + } + }); +} + static void PackQuantBDataAndBlkSum( size_t N, @@ -302,22 +412,49 @@ PackQuantBDataAndBlkSum( size_t SubBlkLen, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) 
{ + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + ComputePackBlkSum(BlkLen, SubBlkLen, N, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + +static void +Q8PackQuantBDataAndBlkSum( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { if (QuantBDataBegin) { - PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); } if (QuantBScaleBegin) { - std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, packed_quant_b.PackedQuantBScale); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); } - if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { - ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, SubBlkLen, N, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, ThreadPool, BlockCountK); } } @@ -326,18 +463,20 @@ PackQuantBDataAndBlkSum( // static size_t -SQ4BitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + bool /* HasZeroPoint */, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t /* BlkBitWidth */ ) { MLAS_UNREFERENCED_PARAMETER(N); switch(ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); // QuantData + Scale + BlkSum @@ -351,15 +490,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } static size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { diff --git a/src/lib/sqnbitgemm_kernel_avx_common_fp32.h b/src/lib/sqnbitgemm_kernel_avx_common_fp32.h index 5cd380e..d15cfc7 100644 --- a/src/lib/sqnbitgemm_kernel_avx_common_fp32.h +++ b/src/lib/sqnbitgemm_kernel_avx_common_fp32.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" template MLAS_FORCEINLINE diff --git a/src/lib/sqnbitgemm_kernel_avx_common_int8.h b/src/lib/sqnbitgemm_kernel_avx_common_int8.h index 895ce6c..2e96082 100644 --- a/src/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/src/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_q8_block.h" diff --git a/src/lib/sqnbitgemm_kernel_lasx.cpp b/src/lib/sqnbitgemm_kernel_lasx.cpp new file mode 100644 index 0000000..04c6540 --- /dev/null +++ b/src/lib/sqnbitgemm_kernel_lasx.cpp @@ -0,0 +1,1089 @@ +/*++ + +Module Name: + + sqnbitgemm_kernel_lasx.cpp + +Abstract: + + This module implements the 
float/quantized n-bit integer matrix + multiplication kernels for loongarch64. Accelerate inference + optimization using lasx/lsx vector instruction sets. + +--*/ + +#include +#include + +#include +#include +#include +#include +#include "core/common/safeint.h" + +#include "qnbitgemm.h" +#include "sqnbitgemm_kernel_lasx_common.h" + +// 1. qnbitgemm.h->Q4BitGemmPackQuantBDataSize +template +static size_t +QNBitGemmPackQuantBDataSize_Lasx( + size_t N, + size_t K, + size_t BlkLen, + bool /* HasZeroPoint */, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + if (ComputeType == SQNBIT_CompInt8) { + SafeInt PackedQuantBDataSize = SafeInt(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const SafeInt ScaleSize = SafeInt(N) * BlockCountK * sizeof(float); + SafeInt BlkSumSize = SafeInt(BlockCountK) * MlasDivRoundup(N, 16) * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += SafeInt(PackedQuantBDataAlignment) - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += SafeInt(BlkSumAlignment) - 1; + + PackedQuantBDataSize += ScaleSize + BlkSumSize; + return PackedQuantBDataSize.Value(); + } else { + SafeInt PackedQuantBDataSize = SafeInt(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize.Value(); + } +} + +// 2. qnbitgemm.h->SQ4BitGemmPackQuantBData +static void +SQ4BitGemmPackQuantBData_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const SafeInt Iterations = SafeInt(N) * BlockCountK; // one iteration per block + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + // + // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | v32 v33 | v34 v33 | + // => + // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations.Value(), + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const SafeInt data_offset = SafeInt(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset.Value(); + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset.Value(); + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +// 3. qnbitgemm.h->SQ4BitGemmPackQuantBDataAndBlkSum +static void +SQ4BitGemmPackQuantBDataAndBlkSum_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // TODO: always use SubBlkLen = 64 in SQNBIT_CompInt8 + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + SafeInt offset = SafeInt(N) * BlockCountK; + std::copy(QuantBScaleBegin, QuantBScaleBegin + offset.Value(), packed_quant_b.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { + ComputePackBlkSum_Lasx( + BlkLen, SubBlkLen, N, + packed_quant_b.PackedQuantBScale, + QuantBZPBegin, + packed_quant_b.QuantBBlkSum, + ThreadPool, + BlockCountK + ); + } +} + +// 3. qnbitgemm.h->SQ8BitGemmPackQuantBDataAndBlkSum +static void +SQ8BitGemmPackQuantBDataAndBlkSum_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + Q8PackQuantBDataAndBlkSum_lasx(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +MLAS_FORCEINLINE +__m256 +load_float_n_lasx(const float* data, int n) +{ + if (n <= 0) { + alignas(32) float zero_array[8] = {0}; + return (__m256)__lasx_xvld((void*)&zero_array, 0); + } + alignas(32) float buf[8] = {0}; + if (n > 0 && n <= 8) { + for (int i = 0; i < n; ++i) { + buf[i] = data[i]; + } + } + return (__m256)__lasx_xvld((void*)&buf, 0); +} + +// ComputeDotProducts_BlkLen32Plus_CompFp32_lasx +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen32 = 32; + constexpr size_t SubBlkStep16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen32); + static_assert(SubBlkStep16 == 16); + + __m256 acc[NCols]; + + alignas(32) static const float zero_array[8] = {0}; + UnrolledLoop([&](size_t i) { + acc[i] = (__m256)__lasx_xvld((void*)&zero_array, 0); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; + [[maybe_unused]] int count_half_4 = 0; + [[maybe_unused]] uint8_t offset[NCols]; + + // TODO: Improve Memory Access Performance with Prefetching Matrix Operations + // alignas(32) float a_buf[2][32] = {0.0}; + //__m256 a_buf[8]; + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt scale_offset = SafeInt(StrideQuantBScale) * i; + scale_v[i] = *(s + scale_offset.Value()); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt data_offset = SafeInt(StrideQuantBData) * i; + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value()); + }); + + // not ready for "Manual conversion to float" in neon yet. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen32) { + size_t kklen = std::min((int)SubBlkLen32, (int)(ck - kk)); + + __m256 av0_8_ps = load_float_n_lasx(ARowPtr + k + kk, std::min(kklen, 8)); + __m256 av1_8_ps = load_float_n_lasx(ARowPtr + k + kk + 8, std::min(kklen > 8 ? kklen - 8 : 0, 8)); + __m256 av2_8_ps = load_float_n_lasx(ARowPtr + k + kk + 16, std::min(kklen > 16 ? kklen - 16 : 0, 8)); + __m256 av3_8_ps = load_float_n_lasx(ARowPtr + k + kk + 24, std::min(kklen > 24 ? 
kklen - 24 : 0, 8)); + + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (int)((kk % (2 * SubBlkLen32)) / SubBlkLen32); + } + + UnrolledLoop([&](size_t i) { + __m256i bv_0_32; + + if constexpr (IsBlkLen64Layout) { + __m256i bv_32_4bit_tmp = __lasx_xvld(b_blk_data_col_ptr[i], 0); + if (!count_half_4) + bv_0_32 = __lasx_xvandi_b(bv_32_4bit_tmp, 0x0F); + else + bv_0_32 = __lasx_xvsrli_b(bv_32_4bit_tmp, 4); + b_blk_data_col_ptr[i] += count_half_4 / 2 * SubBlkStep16; + } else { + // SubBlkLen = 32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on b_blk_data_col_ptr to ensure that it could be read in 16 units + std::memcpy(packed_bytes, b_blk_data_col_ptr[i], 16); + __m256i bv_32_4bit_tmp = __lasx_xvld((void*)&packed_bytes, 0); + __m256i bv_0_15_tmp = __lasx_xvpermi_d(__lasx_xvandi_b(bv_32_4bit_tmp, 0x0F), 0x36); + __m256i bv_16_31_tmp = __lasx_xvpermi_d(__lasx_xvsrli_b(bv_32_4bit_tmp, 4), 0x36); + bv_0_32 = __lasx_xvpermi_d(__lasx_xvpermi_w(bv_16_31_tmp, bv_0_15_tmp, 0xEE), 0x72); + b_blk_data_col_ptr[i] += SubBlkStep16; + } + + __m256i zp = HasZeroPoint ? __lasx_xvldrepl_b((void*)&offset[i], 0) : __lasx_xvrepli_b(0x08); + bv_0_32 = __lasx_xvsub_b(bv_0_32, zp); + + // (1)8bit -> 16bit + __m256i bv_0_15 = __lasx_xvexth_h_b(__lasx_xvpermi_d(bv_0_32, 0x72)); + __m256i bv_16_31 = __lasx_xvexth_h_b(__lasx_xvpermi_d(bv_0_32, 0xD8)); + + // (2)16bit -> int32 + __m256i bv_0_7 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_0_15, 0x72)); + __m256i bv_8_15 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_0_15, 0xD8)); + __m256i bv_16_23 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_16_31, 0x72)); + __m256i bv_24_31 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_16_31, 0xD8)); + + // (3)int32 -> fp32 + __m256 fbv_0_7 = __lasx_xvffint_s_w(bv_0_7); + __m256 fbv_8_15 = __lasx_xvffint_s_w(bv_8_15); + __m256 fbv_16_23 = __lasx_xvffint_s_w(bv_16_23); + __m256 fbv_24_31 = __lasx_xvffint_s_w(bv_24_31); + + __m256 scale_ps = (__m256)__lasx_xvldrepl_w(&scale_v[i], 0); + + fbv_0_7 = __lasx_xvfmul_s(fbv_0_7, scale_ps); + fbv_8_15 = __lasx_xvfmul_s(fbv_8_15, scale_ps); + fbv_16_23 = __lasx_xvfmul_s(fbv_16_23, scale_ps); + fbv_24_31 = __lasx_xvfmul_s(fbv_24_31, scale_ps); + + acc[i] = __lasx_xvfmadd_s(fbv_0_7, av0_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_8_15, av1_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_16_23, av2_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_24_31, av3_8_ps, acc[i]); + }); + } + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + ++s; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators_Lasx(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = __lsx_vfadd_s(acc_x, (__m128)__lsx_vld((void*)bias_ptr, 0)); + } + __lsx_vst(acc_x, sum_ptr, 0); + } else { + UnrolledLoop([&](size_t i) { + float sum = hsum_float_8_lasx(acc[i]); + float bias_tmp = bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; + sum_ptr[i] = sum + bias_tmp; + }); + } +} + +// ComputeDotProducts_BlkLen16_CompFp32_lasx +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen16_CompFp32_lasx( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen16 = 16; + constexpr size_t SubBlkStep8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen16); + static_assert(SubBlkStep8 == 8); + + __m256 acc[NCols]; + alignas(32) int zero_array[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + UnrolledLoop([&](size_t i) { + acc[i] = (__m256)__lasx_xvld((void*)&zero_array, 0); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; + [[maybe_unused]] uint8_t offset[NCols]; + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt scale_offset = SafeInt(StrideQuantBScale) * i; + scale_v[i] = *(s + scale_offset.Value()); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt data_offset = SafeInt(StrideQuantBData) * i; + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value()); + }); + + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen16) { + size_t kklen = std::min((int)SubBlkLen16, (int)(ck - kk)); + + __m256 av_lo = load_float_n_lasx(ARowPtr + k + kk, std::min(kklen, 8)); + __m256 av_hi = load_float_n_lasx(ARowPtr + k + kk + 8, std::min(kklen > 8 ? kklen - 8 : 0, 8)); + + UnrolledLoop([&](size_t i) { + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on b_blk_data_col_ptr to ensure that it could be read in 8 units + std::memcpy(packed_bytes + 24, b_blk_data_col_ptr[i], 8); + __m256i B_16val = __lasx_xvld((void*)&packed_bytes, 0); + + /* + low->high + | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | x 3 + | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | 24-31 + */ + + b_blk_data_col_ptr[i] += SubBlkStep8; + __m256i lower = __lasx_xvandi_b(B_16val, 0x0F); + __m256i upper = __lasx_xvsrli_b(B_16val, 4); + __m256i packb = __lasx_xvpermi_d(__lasx_xvpackod_d(upper, lower), 0xD8); + + __m256i zp = HasZeroPoint ? 
__lasx_xvldrepl_b((void*)&offset[i], 0) : __lasx_xvrepli_b(0x08); + packb = __lasx_xvsub_b(packb, zp); + __m256i bv0_15 = __lasx_xvexth_h_b(packb); + + __m256i bv0_7 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv0_15, 0x72)); + __m256i bv8_15 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv0_15, 0xD8)); + + __m256 fbv0_7 = __lasx_xvffint_s_w(bv0_7); + __m256 fbv8_15 = __lasx_xvffint_s_w(bv8_15); + __m256 scale = (__m256)__lasx_xvldrepl_w((void*)&scale_v[i], 0); + fbv0_7 = __lasx_xvfmul_s(fbv0_7, scale); + fbv8_15 = __lasx_xvfmul_s(fbv8_15, scale); + + acc[i] = __lasx_xvfmadd_s(av_lo, fbv0_7, acc[i]); + acc[i] = __lasx_xvfmadd_s(av_hi, fbv8_15, acc[i]); + }); + } + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + ++s; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators_Lasx(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = __lsx_vfadd_s(acc_x, (__m128)__lsx_vld((void*)bias_ptr, 0)); + } + __lsx_vst(acc_x, sum_ptr, 0); + } else { + UnrolledLoop([&](size_t i) { + float sum = 0.0f; + alignas(32) float acc_buf[8]; + __lasx_xvst(acc[i], (void*)&acc_buf, 0); + UnrolledLoop<8>([&](size_t j) { sum += acc_buf[j]; }); + float bias_tmp = bias_ptr == nullptr ? 0.0f : bias_ptr[i]; + sum_ptr[i] = sum + bias_tmp; + }); + } +} + +// SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = (CountN) - NCols4; + while (nblk >= 0) { + ComputeDotProducts_BlkLen16_CompFp32_lasx( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + SafeInt data_offset = SafeInt(StrideQuantBData) * NCols4; + SafeInt scale_offset = SafeInt(StrideQuantBScale) * NCols4; + QuantBDataColPtr += data_offset.Value(); + QuantBScaleColPtr += scale_offset.Value(); + if constexpr (HasZeroPoint) { + SafeInt zeropoint_offset = SafeInt(StrideQuantBZeroPoint) * NCols4; + QuantBZeroPointColPtr += zeropoint_offset.Value(); + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkLen16_CompFp32_lasx<1, HasZeroPoint>( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +// SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + + while (nblk >= 0) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + SafeInt data_offset = SafeInt(StrideQuantBData) * NCols4; + SafeInt scale_offset = SafeInt(StrideQuantBScale) * NCols4; + QuantBDataColPtr += data_offset.Value(); + QuantBScaleColPtr += scale_offset.Value(); + if constexpr (HasZeroPoint) { + SafeInt zeropoint_offset = SafeInt(StrideQuantBZeroPoint) * NCols4; + QuantBZeroPointColPtr += zeropoint_offset.Value(); + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than NCols + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx<1, HasZeroPoint, true>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx<1, HasZeroPoint, false>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32_Lasx( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + BlkLen, A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + BlkLen, A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } + } +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx( + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + + constexpr size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + /* + TODO: constexpr use template parameter + Since QuantBZeroPoint is a model parameter and cannot be determined at compile time, constexpr cannot be used + and comments are required, However, when the usage scenario can be determined, constexpr can be used to enhance + performance. 
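+
+        For illustration only (a hypothetical refactor, not part of this code): splitting the
+        function on a template parameter and branching once at the call site, e.g.
+            template <bool HasZeroPoint> void DequantBlkLen16(...);
+            QuantBZeroPoint ? DequantBlkLen16<true>(...) : DequantBlkLen16<false>(...);
+        would let the compiler discard the unused zero-point path inside the loops.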
+ */ + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time + constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 + for (size_t col = 0; col < CountN; col += NCols8) { + const int cols = std::min((int)NCols8, (int)CountN - (int)col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + SafeInt offset = SafeInt(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16; + float* dst_ptr = FpData + offset.Value(); + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + SafeInt b_data_offset = SafeInt(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + SafeInt b_scale_offset = SafeInt(col) * BlockCountK + k; + SafeInt b_zp_offset = SafeInt(col) * zp_col_stride_in_bytes + k / 2; + const std::byte* b_data_ptr = QuantBData + b_data_offset.Value(); + const float* scale_ptr = QuantBScale + b_scale_offset.Value(); + const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value(); + bool is_lower = (k % 2) == 0; + + __m256i weight_16_epi16[NCols8]; + __m256 scale_8_ps[NCols8]; + UnrolledLoop([&](size_t col_) { + if ((int)col_ < cols) { + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on QuantBData to ensure that it could be read in 8 units + std::memcpy(packed_bytes + 24, b_data_ptr + col_ * b_data_col_stride_in_bytes, 8); + __m256i B_16val = __lasx_xvld((void*)&packed_bytes, 0); + // low->high + // | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | x 3 + // | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | 24-31 + + __m256i lower = __lasx_xvandi_b(B_16val, 0x0F); + __m256i upper = __lasx_xvsrli_b(B_16val, 4); + __m256i packb = __lasx_xvpermi_d(__lasx_xvpackod_d(upper, lower), 0xD8); + + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + __m256i zero_point = __lasx_xvreplgr2vr_b(static_cast(zp)); + packb = __lasx_xvsub_b(packb, zero_point); + } else { + __m256i zero_point = __lasx_xvrepli_b(0x08); + packb = __lasx_xvsub_b(packb, zero_point); + } + weight_16_epi16[col_] = __lasx_xvexth_h_b(packb); + scale_8_ps[col_] = (__m256)__lasx_xvldrepl_w((void*)(scale_ptr + col_ * BlockCountK), 0); + } else { + weight_16_epi16[col_] = __lasx_xvrepli_d(0); + scale_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + }); + + for (int i_of_2 = 0; i_of_2 < 2; i_of_2++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if ((int)col_ < cols) { + if (i_of_2 == 0) { + __m256i weight_i_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_16_epi16[col_], 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_8_epi32), scale_8_ps[col_]); + } else { + __m256i weight_i_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_16_epi16[col_], 0xD8)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + } + // transpose and store + __m256 a0 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0x44); // a1, a2, b1, b2, a5, a6, b5, b6 + __m256 a1 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0xEE); // a3, a4, b3, b4, a7, a8, b7, b8 + __m256 a2 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0x44); // c1, c2, d1, d2, c5, c6, d5, d6 + __m256 a3 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0xEE); // c3, c4, d3, d4, c7, c8, d7, d8 + __m256 a4 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0x44); // e1, e2, f1, f2, e5, e6, f5, f6 + __m256 a5 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0xEE); // e3, e4, f3, f4, e7, e8, f7, f8 + __m256 a6 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0x44); // g1, g2, h1, h2, g5, g6, h5, h6 + __m256 a7 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0xEE); // g3, g4, h3, h4, g7, g8, h7, h8 + + __m256 b0 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0x88); // a1, b1, c1, d1, a5, b5, c5, d5 + __m256 b1 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0xDD); // a2, b2, c2, d2, a6, b6, c6, d6 + __m256 b2 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0x88); // a3, b3, c3, d3, a7, b7, c7, d7 + __m256 b3 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0xDD); // a4, b4, c4, d4, a8, b8, c8, d8 + __m256 b4 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0x88); // e1, f1, g1, h1, e5, f5, g5, h5 + __m256 b5 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0xDD); // e2, f2, g2, h2, e6, f6, g6, h6 + __m256 b6 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0x88); // e3, f3, g3, h3, e7, f7, g7, h7 + __m256 b7 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0xDD); // e4, f4, g4, h4, e8, f8, g8, h8 + + // next i_of_2th row + const size_t ij_offset_in_k = i_of_2 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x02); // a1, b1, c1, d1, e1, f1, g1, h1 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x02); // a2, b2, c2, d2, e2, f2, g2, h2 + __lasx_xvst(weight_transposed_8_ps, dst_ptr 
+ ij_offset_in_k + 1 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x02); // a3, b3, c3, d3, e3, f3, g3, h3 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x02); // a4, b4, c4, d4, e4, f4, g4, h4 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x13); // a5, b5, c5, d5, e5, f5, g5, h5 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x13); // a6, b6, c6, d6, e6, f6, g6, h6 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x13); // a7, b7, c7, d7, e7, f7, g7, h7 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x13); // a8, b8, c8, d8, e8, f8, g8, h8 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, 0); + } + } + } +} + +template +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t GemmFloatKernelWidth16 = 16; + constexpr size_t SubblkLen32 = 32; + + const size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t subblk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubblkLen32); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + /* + TODO: constexpr use template parameter + Since QuantBZeroPoint is a model parameter and cannot be determined at compile time, constexpr cannot be used + and comments are required, However, when the usage scenario can be determined, constexpr can be used to enhance + performance. 
+ */ + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] int count_half_4 = 0; + + for (size_t col = 0; col < CountN; col += NCols8) { + // TODO: handle last tile with cols < NCols8 + const size_t cols = std::min(NCols8, CountN - col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + SafeInt offset = SafeInt(tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16; + float* dst_ptr = FpData + offset.Value(); + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + SafeInt b_data_offset = SafeInt(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + SafeInt b_scale_offset = SafeInt(col) * BlockCountK + k; + SafeInt b_zp_offset = SafeInt(col) * zp_col_stride_in_bytes + k / 2; + const std::byte* b_data_ptr = QuantBData + b_data_offset.Value(); + const float* scale_ptr = QuantBScale + b_scale_offset.Value(); + const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value(); + bool is_lower = (k % 2) == 0; + + for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) { + __m256i weight_32_epi8[NCols8]; + __m256 scale_8_ps[NCols8]; + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (subblk % 2); + } + UnrolledLoop([&](size_t col_) { + // 1. load 32 4-bit data + if (col_ < cols) { + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 + // at the end of subblk loop, increment b_data_ptr by 2 * subblk_data_size_in_bytes if subblk % 2 == 1 + // so that all v0-64 of the pack are dequantized. + __m256i bv_32_4bit_tmp = __lasx_xvld(b_data_ptr + col_ * b_data_col_stride_in_bytes, 0); + if (!count_half_4) + weight_32_epi8[col_] = __lasx_xvandi_b(bv_32_4bit_tmp, 0x0F); + else + weight_32_epi8[col_] = __lasx_xvsrli_b(bv_32_4bit_tmp, 4); + } else { + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on QuantBData to ensure that it could be read in 16 units + std::memcpy(packed_bytes, b_data_ptr + col_ * b_data_col_stride_in_bytes, 16); + __m256i bv_32_4bit_tmp = __lasx_xvld((void*)&packed_bytes, 0); + __m256i bv_0_15_tmp = __lasx_xvpermi_d(__lasx_xvandi_b(bv_32_4bit_tmp, 0x0F), 0x36); + __m256i bv_16_31_tmp = __lasx_xvpermi_d(__lasx_xvsrli_b(bv_32_4bit_tmp, 4), 0x36); + weight_32_epi8[col_] = __lasx_xvpermi_d(__lasx_xvpermi_w(bv_16_31_tmp, bv_0_15_tmp, 0xEE), 0x72); + } + + // 2. load zeropoint and scale + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + __m256i zero_point = __lasx_xvreplgr2vr_b(static_cast(zp)); + weight_32_epi8[col_] = __lasx_xvsub_b(weight_32_epi8[col_], zero_point); + } else { + __m256i zero_point = __lasx_xvrepli_b(0x08); + weight_32_epi8[col_] = __lasx_xvsub_b(weight_32_epi8[col_], zero_point); + } + + scale_8_ps[col_] = (__m256)__lasx_xvldrepl_w((void*)(scale_ptr + col_ * BlockCountK), 0); + } else { + weight_32_epi8[col_] = __lasx_xvrepli_d(0); + scale_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + }); + + for (int i_of_4 = 0; i_of_4 < 4; i_of_4++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if (col_ < cols) { + if (i_of_4 == 0) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(__lasx_xvpermi_d(weight_32_epi8[col_], 0xE1)); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 1) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(weight_32_epi8[col_]); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 2) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(__lasx_xvpermi_d(weight_32_epi8[col_], 0xD8)); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 3) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(weight_32_epi8[col_]); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0xD8)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + } + // transpose and store + __m256 a0 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0x44); // a1, a2, b1, b2, a5, a6, b5, b6 + __m256 a1 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0xEE); // a3, a4, b3, b4, a7, a8, b7, b8 + __m256 a2 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0x44); // c1, c2, d1, d2, c5, c6, d5, d6 + __m256 a3 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0xEE); // c3, c4, d3, d4, c7, c8, d7, d8 + __m256 a4 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0x44); // e1, e2, f1, f2, e5, e6, f5, f6 + __m256 a5 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0xEE); // e3, e4, f3, f4, e7, e8, f7, f8 + __m256 a6 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0x44); // g1, g2, h1, h2, g5, g6, h5, h6 + __m256 a7 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0xEE); // g3, g4, h3, h4, g7, g8, h7, h8 + + __m256 b0 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0x88); // a1, b1, c1, d1, a5, b5, c5, d5 + __m256 b1 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0xDD); // a2, b2, c2, d2, a6, b6, c6, d6 + __m256 b2 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0x88); // a3, b3, c3, d3, a7, b7, c7, d7 + __m256 b3 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0xDD); // a4, b4, c4, d4, a8, b8, c8, d8 + __m256 b4 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0x88); // e1, f1, g1, h1, e5, f5, g5, h5 + __m256 b5 
= (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0xDD); // e2, f2, g2, h2, e6, f6, g6, h6 + __m256 b6 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0x88); // e3, f3, g3, h3, e7, f7, g7, h7 + __m256 b7 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0xDD); // e4, f4, g4, h4, e8, f8, g8, h8 + + // next i_of_2th row + const size_t ij_offset_in_k = i_of_4 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x02); // a1, b1, c1, d1, e1, f1, g1, h1 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x02); // a2, b2, c2, d2, e2, f2, g2, h2 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x02); // a3, b3, c3, d3, e3, f3, g3, h3 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x02); // a4, b4, c4, d4, e4, f4, g4, h4 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x13); // a5, b5, c5, d5, e5, f5, g5, h5 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x13); // a6, b6, c6, d6, e6, f6, g6, h6 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x13); // a7, b7, c7, d7, e7, f7, g7, h7 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x13); // a8, b8, c8, d8, e8, f8, g8, h8 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, 0); + } + dst_ptr += SubblkLen32 * GemmFloatKernelWidth16; + if constexpr (IsBlkLen64Layout) { + b_data_ptr += (subblk % 2) * 2 * subblk_data_size_in_bytes; + } else { + b_data_ptr += subblk_data_size_in_bytes; + } + } // subblk + } + } +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemm_CompFp32_Lasx( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockStrideQuantB +) +{ + if (BlkLen == 16) { + Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx( + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else if (BlkLen == 32) { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } +} + +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx = []() { + MLAS_QNBIT_GEMM_DISPATCH d; + + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize_Lasx<4>; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData_Lasx; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum_Lasx; + d.SQ8BitGemmPackQuantBDataAndBlkSum = 
SQ8BitGemmPackQuantBDataAndBlkSum_Lasx; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_Lasx; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_Lasx; + + return d; +}(); diff --git a/src/lib/sqnbitgemm_kernel_lasx_common.h b/src/lib/sqnbitgemm_kernel_lasx_common.h new file mode 100644 index 0000000..508bcba --- /dev/null +++ b/src/lib/sqnbitgemm_kernel_lasx_common.h @@ -0,0 +1,514 @@ +/*++ + Abstract: + + Lasx/Lsx tool function, Auxiliary functions for inference required by + 4-bit/8-bit quantization models. +--*/ +#pragma once +#include "qnbitgemm.h" +#include "core/common/safeint.h" +#include + +template +struct MlasAlignedAllocator { + using value_type = T; + + MlasAlignedAllocator() = default; + + template + MlasAlignedAllocator(const MlasAlignedAllocator&) {} + + T* allocate(size_t n) { + // If RequiredAlignment > 0, use the required value + // Otherwise, use the value of MlasGetPreferredBufferAlignment() + size_t alignment = RequiredAlignment > 0 ? + RequiredAlignment : + MlasGetPreferredBufferAlignment(); + + size_t size = n * sizeof(T); + if (size % alignment != 0) // check the size + size = ((size + alignment - 1) / alignment) * alignment; + #if defined(_MSC_VER) + void* ptr = _aligned_malloc(size, alignment); + #else + void* ptr = aligned_alloc(alignment, size); + #endif + if (!ptr) throw std::bad_alloc(); + return static_cast(ptr); + } + + void deallocate(T* ptr, size_t) { + #if defined(_MSC_VER) + _aligned_free(ptr); + #else + free(ptr); + #endif + } + + template + struct rebind { + using other = MlasAlignedAllocator; + }; +}; + +static MLAS_FORCEINLINE __m256 +__lasx_xvzero() +{ + return (__m256)__lasx_xvldi(0); +} + +static size_t +GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) +{ + size_t T = n / 4, t = n % 4; + bool te = T == N / 4; + SafeInt scale_dst_offset = SafeInt(T) * 4 * SubOrBlkCountK; + if (te) { + scale_dst_offset += SafeInt(t) * SubOrBlkCountK + k_sub_or_blk; + } else { + scale_dst_offset += SafeInt(k_sub_or_blk) * 4 + t; + } + return scale_dst_offset.Value(); +} + +static size_t +GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub) +{ + size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub; + bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub; + SafeInt scale_dst_offset = SafeInt(T) * 4 * BlockCountK; + if (te) { + scale_dst_offset += SafeInt(t) * BlockCountK + k_blk; + } else { + scale_dst_offset += SafeInt(k_subblk) * blks_per_sub * 4; + if (be) { + scale_dst_offset += SafeInt(b) * 4 + t; + } else { + scale_dst_offset += SafeInt(t) * blks_per_sub + b; + } + } + return scale_dst_offset.Value(); +} + +static void +ComputePackBlkSum_Lasx( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK +) +{ + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const SafeInt src_blk_offset = SafeInt(n) * BlockCountK + k_blk; + float QuantBScale = QuantBScaleBegin[src_blk_offset.Value()]; + uint8_t zp = 8; + + if (QuantBZPBegin) { + size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); + SafeInt src_zp_offset = SafeInt(ZPCountK) * n + k_blk / 2; + bool low_zp = k_blk % 2 == 0; + const std::byte* QuantBZP = 
QuantBZPBegin + src_zp_offset.Value(); + const std::byte low_mask{0X0f}; + zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); + } + + float result = -QuantBScale * zp; + + const SafeInt dst_offset = ( SafeInt(n / 16) * BlockCountK + k_blk) * 16 + n % 16; + BlockSumBegin[dst_offset.Value()] = result; + + if (BlkLen == 16) { + } else if (BlkLen >= SubBlkLen) { + const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk); + QuantBScaleBegin[scale_dst_offset] = QuantBScale; + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + QuantBScaleBegin[scale_dst_offset] = QuantBScale; + } + }); +} + +static void +PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen +) +{ + constexpr size_t BlkBitWidth = 4; + const size_t BlkBytePairCount = BlkLen / 4; + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / SubBlkCountK; + const size_t k_subblk = tid % SubBlkCountK; + + const SafeInt src_data_offset = SafeInt(n) * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + src_data_offset.Value(); + + size_t PackBytePairCount = SubBlkBytePairCount; + size_t PackDataSize = SubBlkDataSize; + + auto pack_subblk = []( + const std::byte* QuantBData, std::byte* PackedQuantBData, + size_t pack_byte_pair_count, size_t pack_data_size + ) { + for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0f}) | ((src1 & std::byte{0x0f}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } }; + + if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1 && + SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { + // this is the last subblk of the column. check if it extends out of the + // BlockCountK. If it does, we shall pack per blocks so that can compute + // on each block instead of each subblk. 
+ PackBytePairCount = BlkBytePairCount; + PackDataSize = BlkDataSize; + const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; + for (size_t k = 0; k < k_blks_remaining; k++) { + const SafeInt k_blk = SafeInt(k_subblk) * SubBlkLen / BlkLen + k; + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + // shall not reach here with avx2 + assert(SubBlkLen == 128); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk.Value(), blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } else { + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const SafeInt k_blk = SafeInt(k_subblk) * blks_per_sub; + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk.Value(), blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } + ); +} + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +static MLAS_FORCEINLINE __m128 +FoldAccumulators_Lasx(const __m256& acc0, const __m256& acc1, const __m256& acc2, const __m256& acc3) +{ + /* + acc0 = [A0, A1, A2, A3, A4, A5, A6, A7] + acc1 = [B0, B1, B2, B3, B4, B5, B6, B7] + */ + + __m256 tmpAB_lo = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc1, acc0, 0x44), 0xD8); // a1,a2,a5,a6,b1,b2,b5,b6 + __m256 tmpAB_hi = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc1, acc0, 0xEE), 0xD8); // a3,a4,a7,a8,b3,b4,b7,b8 + __m256 tmpCD_lo = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc3, acc2, 0x44), 0xD8); // c1,c2,c5,c6,d1,d2,d5,d6 + __m256 tmpCD_hi = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc3, acc2, 0xEE), 0xD8); // c3,c4,c7,c8,d3,d4,d7,d8 + + __m256 tmpABCD_lo1 = (__m256)__lasx_xvpermi_w(tmpCD_lo, tmpAB_lo, 0x44); // a1,a2,c1,c2,b1,b2,d1,d2 + __m256 tmpABCD_lo2 = (__m256)__lasx_xvpermi_w(tmpCD_hi, tmpAB_hi, 0x44); // a3,a4,c3,c4,b3,b4,d3,d4 + __m256 tmpABCD_hi1 = (__m256)__lasx_xvpermi_w(tmpCD_lo, tmpAB_lo, 0xEE); // a5,a6,c5,c6,b5,b6,d5,d6 + __m256 tmpABCD_hi2 = (__m256)__lasx_xvpermi_w(tmpCD_hi, tmpAB_hi, 0xEE); // a7,a8,c7,c8,b7,b8,d7,d8 + + __m256 sumABCD = __lasx_xvfadd_s(__lasx_xvfadd_s(tmpABCD_lo1, tmpABCD_lo2), __lasx_xvfadd_s(tmpABCD_hi1, tmpABCD_hi2)); + + __m256 sum0 = (__m256)__lasx_xvpermi_w(sumABCD, sumABCD, 0xB1); + sumABCD = 
(__m256)__lasx_xvpermi_d(__lasx_xvfadd_s(sumABCD, sum0), 0xD8); + + sumABCD = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(sumABCD, sumABCD, 0x88), 0xD8); + + alignas(32) float tmp[8]; + __lasx_xvst(sumABCD, (void*)&tmp, 0); + __m128 result = (__m128)__lsx_vld((void*)&tmp, 0); + return result; +} + +__m256 +permutevar_ps_lasx(__m256 vec, __m256i idx_mask) +{ + __m256i veci = (__m256i)vec; + __m256i shuffled = __lasx_xvshuf_w(veci, veci, idx_mask); + return (__m256)shuffled; +} + +static void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen +) +{ + constexpr size_t BlkBitWidth = 8; + const size_t StrideN = BlockCountK * BlkLen; + const size_t BlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubBlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); + const size_t SubBlkCountK = MlasDivRoundup(StrideN, SubBlkLen); + const size_t RemainderBlockCountK = BlockCountK % (SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // SubBlkLen rows x 4 columns pack together, then remainder BlkLen x 4 columns if SubBlkLen > BlkLen. + // remainder columns keep the original order. + // SubBlkLen >= 16 and is multiple of 16 + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / SubBlkCountK; + const size_t c_4 = c & (~3), c_res = c & 3; + const size_t r_subblk = tid % SubBlkCountK; + + const SafeInt data_offset = SafeInt(c) * StrideN + r_subblk * SubBlkLen; + const std::byte* src = QuantBDataBegin + data_offset.Value(); + + if (c_4 + 4 <= N) { // full 4 cols + if (RemainderBlockCountK && r_subblk == SubBlkCountK - 1) { // remainder blocks + const SafeInt subblk_data_offset = SafeInt(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * BlkSize; + std::byte* dest = + PackedQuantBDataBegin + subblk_data_offset.Value(); + for (size_t i = 0; i < RemainderBlockCountK; i++) { + std::copy(src, src + BlkSize, dest); + src += BlkSize; + dest += BlkSize * 4; + } + } else { // full subblock + const SafeInt subblk_data_offset = SafeInt(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize; + std::byte* dest = + PackedQuantBDataBegin + subblk_data_offset.Value(); + std::copy(src, src + SubBlkSize, dest); + } + } else { // remainder cols + const SafeInt remain_data_offset = SafeInt(c) * StrideN + r_subblk * SubBlkSize; + std::byte* dest = + PackedQuantBDataBegin + remain_data_offset.Value(); + std::copy(src, src + std::min(SubBlkSize, StrideN - r_subblk * SubBlkSize), dest); + } + } + ); +} + +static void +Q8ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK +) +{ + SafeInt size = SafeInt(N) * BlockCountK; + std::vector> QuantBScaleBeginCopy(size.Value()); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n_4 = n & (~3), n_res = n & 3; + const size_t k_blk = tid % BlockCountK; + + const SafeInt src_blk_offset = SafeInt(n) * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset.Value()]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const 
std::byte* QuantBZP = QuantBZPBegin + src_blk_offset.Value(); + zp = (uint8_t)(*QuantBZP); + } + + const SafeInt dst_offset = ( SafeInt(n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset.Value()) = -QuantBScale * zp; + + if (n_4 + 4 > N) { + SafeInt ptr_offset = SafeInt(n) * BlockCountK + k_blk; + *(QuantBScaleBegin + ptr_offset.Value()) = QuantBScale; + } else if (BlkLen >= SubBlkLen) { + SafeInt ptr_offset = SafeInt(n_4) * BlockCountK + k_blk * 4 + n_res; + *(QuantBScaleBegin + ptr_offset.Value()) = QuantBScale; + } else { + size_t blks_per_sub = SubBlkLen / BlkLen; + size_t remainder_blk = BlockCountK % blks_per_sub; + size_t sub_blk_count_k = MlasDivRoundup(BlockCountK, blks_per_sub); + size_t k_subblk = k_blk / blks_per_sub; + size_t k_blk_res = k_blk % blks_per_sub; + SafeInt dest_offset; + + if (remainder_blk && k_subblk == sub_blk_count_k - 1) { // remainder blocks + dest_offset = SafeInt(n_4) * BlockCountK + k_blk * 4 + n_res; + } else { // full subblock + dest_offset = SafeInt(n_4) * BlockCountK + k_subblk * blks_per_sub * 4 + n_res * blks_per_sub + k_blk_res; + } + + *(QuantBScaleBegin + dest_offset.Value()) = QuantBScale; + } + }); +} + +static void +Q8PackQuantBDataAndBlkSum_lasx( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, SubBlkLen, N, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + +static MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8_lasx(__m256 v0, __m256 v1) +{ + // fp32->int32 + __m256i v0_8_epi32 = __lasx_xvftint_w_s(__lasx_xvfrint_s(v0)); + __m256i v1_8_epi32 = __lasx_xvftint_w_s(__lasx_xvfrint_s(v1)); + + alignas(32) int val_0_15_i32[16] = {0}; + alignas(32) int8_t val_0_15_i8[16] = {0}; + + __lasx_xvst(v0_8_epi32, (void*)&val_0_15_i32, 0); + __lasx_xvst(v1_8_epi32, (void*)&val_0_15_i32, 32); + + UnrolledLoop<16>([&](size_t i) { + if (val_0_15_i32[i] > 127) + val_0_15_i8[i] = 127; + else if (val_0_15_i32[i] < -128) + val_0_15_i8[i] = -128; + else + val_0_15_i8[i] = static_cast(val_0_15_i32[i]); + }); + + __m128i result = __lsx_vld((void*)&val_0_15_i8, 0); + return result; +} + +static inline __m256i +lasx_maddubs_epi16_sat(__m256i a, __m256i b) +{ + // a: bytes interpreted as unsigned + // b: bytes interpreted as signed + __m256i zero_h = __lasx_xvldi(0); // 256-bit zeros + + __m256i even_prod16 = __lasx_xvmaddwev_h_bu_b(zero_h, a, b); + __m256i odd_prod16 = __lasx_xvmaddwod_h_bu_b(zero_h, a, b); + + __m256i sum16_sat = __lasx_xvsadd_h(even_prod16, odd_prod16); + + return sum16_sat; // 16-bit signed saturated results (16 lanes) +} + +static inline __m256i +lasx_madd_epi16(__m256i a, __m256i b) +{ + __m256i zero = __lasx_xvldi(0); + __m256i even_acc = __lasx_xvmaddwev_w_h(zero, a, b); + __m256i result = __lasx_xvmaddwod_w_h(even_acc, a, b); + + return result; // 32-bit signed sums, matches _mm256_madd_epi16 semantics (no saturation) +} + +static inline __m256i +lasx_hadd_epi32(__m256i a, __m256i b) +{ 
+ __m256i a_swapped = __lasx_xvshuf4i_w(a, 0xB1); // 0xB1 = binary 10110001 + __m256i b_swapped = __lasx_xvshuf4i_w(b, 0xB1); + + __m256i a_sum = __lasx_xvadd_w(a, a_swapped); + __m256i b_sum = __lasx_xvadd_w(b, b_swapped); + + __m256i a_even = __lasx_xvpermi_w(a_sum, a_sum, 0x88); + __m256i b_even = __lasx_xvpermi_w(b_sum, b_sum, 0x88); + + __m256i result = __lasx_xvpermi_q(a_even, b_even, 0x20); + + return result; +} + +static inline __m256i +lasx_cvtepu8_epi16_emul_from_m128(const __m128i v128) +{ + alignas(32) int8_t num[32] = {0}; + __lsx_vst(v128, (void*)&num, 0); + __m256i result = __lasx_xvld((void*)&num, 0); + result = __lasx_xvexth_hu_bu(__lasx_xvpermi_d(result, 0x72)); + return result; +} + +static MLAS_FORCEINLINE float +hsum_float_8_lasx(__m256 v) +{ + v = __lasx_xvfadd_s(v, (__m256)__lasx_xvpermi_d(v, 0xB1)); + v = __lasx_xvfadd_s(v, (__m256)__lasx_xvpermi_d(v, 0x4E)); + alignas(32) float num[8] = {0.0f}; + __lasx_xvst(v, (void*)num, 0); + + return num[0] + num[1]; +} diff --git a/src/lib/sqnbitgemm_kernel_neon.cpp b/src/lib/sqnbitgemm_kernel_neon.cpp deleted file mode 100644 index 3f32cc6..0000000 --- a/src/lib/sqnbitgemm_kernel_neon.cpp +++ /dev/null @@ -1,194 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - sqnbitgemm_kernel_neon.cpp - -Abstract: - - This module implements the float/quantized n-bit integer matrix - multiplication kernels for ARM NEON. - ---*/ - -#include - -#include - -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" -#include "sqnbitgemm_q8_block.h" - -namespace sqnbitgemm_neon -{ - -namespace -{ - -// -// Quantized B data packing function implementation. -// - -size_t -SQ4BitGemmPackQuantBDataSize( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - - constexpr size_t BlkBitWidth = 4; - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; -} - -void -SQ4BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen >= 16 && BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t Iterations = N * BlockCountK; // one iteration per block - - const size_t SubBlkLen = (ComputeType == CompInt8) - ? ((BlkLen == 16) ? 16 : 32) - : 16; - - const size_t SubBlkDataSize = SubBlkLen / 2; - const size_t SubBlkBytePairCount = SubBlkLen / 4; - - // - // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | - // => - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // - - // - // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | - // => - // dst: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | - // - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + data_offset; - std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - - for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { - for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; - - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | ((src1 >> 4) << 4); - } - - QuantBData += SubBlkDataSize; - PackedQuantBData += SubBlkDataSize; - } - } - ); -} - -// -// Workspace size calculation function implementation. -// - -size_t -SQ4BitGemmPerGemmWorkspaceSize( - size_t M, - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - MLAS_UNREFERENCED_PARAMETER(N); - - switch (ComputeType) { - case CompInt8: { - // workspace buffer is used for block quantization of A to int8 - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); - return PerGemmWorkspaceSize; - } - default: { - return 0; - } - } -} - -size_t -SQ4BitGemmPerGemmWorkspaceAlignment( - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType -) -{ - MLAS_UNREFERENCED_PARAMETER(BlkLen); - - switch (ComputeType) { - case CompInt8: { - return Q8BlkAlignment(); - } - default: { - return 1; - } - } -} - -} // namespace - -} // namespace sqnbitgemm_neon - -// -// Kernel dispatch structure definition. -// - -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; - - d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - - d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; - - d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - - d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; - d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; - - return d; -}(); diff --git a/src/lib/sqnbitgemm_kernel_neon_fp32.cpp b/src/lib/sqnbitgemm_kernel_neon_fp32.cpp index 12ddc42..31a499b 100644 --- a/src/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/src/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompFp32. --*/ @@ -21,8 +21,8 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon { @@ -31,7 +31,7 @@ namespace { // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. 
// MLAS_FORCEINLINE void @@ -608,7 +608,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_Impl( } // namespace void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, diff --git a/src/lib/sqnbitgemm_kernel_neon_int8.cpp b/src/lib/sqnbitgemm_kernel_neon_int8.cpp index 0d62ea3..b03b812 100644 --- a/src/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/src/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -1,7 +1,6 @@ /*++ Copyright (c) Microsoft Corporation. All rights reserved. - Licensed under the MIT License. Module Name: @@ -13,23 +12,29 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8. --*/ #include #include +#include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" +#ifdef USE_KLEIDIAI +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_ukernel_interface.h" +#endif + namespace sqnbitgemm_neon { // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. // namespace @@ -126,6 +131,41 @@ QuantizeBlock( } // namespace +bool +UsePacked_CompInt8(size_t K, size_t BlkLen, bool HasZp) +{ + return UseKleidiAI(K, BlkLen, HasZp); +} + +#ifdef USE_KLEIDIAI +void +QuantizeA_Packed_CompInt8( + size_t, + const float* A, + size_t CountM, + size_t CountK, + std::byte* QuantA +) +{ + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = + CountM == 1? GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); + + const size_t mr = ukernel.get_mr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + + const size_t src_stride = CountK * sizeof(float); + const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(0, src_stride); + const size_t lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32( + 0, CountK, mr, kr, sr); + + const float* src_ptr = reinterpret_cast(reinterpret_cast(A) + lhs_offset); + void* dst_ptr = QuantA + lhs_packed_offset; + + kai_run_lhs_quant_pack_qai8dxp_f32(CountM, CountK, mr, kr, sr, 0, src_ptr, src_stride, dst_ptr); +} +#endif + void QuantizeARow_CompInt8( size_t BlkLen, @@ -147,6 +187,230 @@ QuantizeARow_CompInt8( } } +MLAS_FORCEINLINE +float32x4_t LoadFloat32x4(const float* src, size_t count) +{ + if (count == 4) { + return vld1q_f32(src); + } else if (count == 3) { + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + v = vld1q_lane_f32(src + 1, v, 1); + v = vld1q_lane_f32(src + 2, v, 2); + return v; + } else if (count == 2) { + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + v = vld1q_lane_f32(src + 1, v, 1); + return v; + } else { + assert(count == 1); + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + return v; + } +} + +template +using I16VecType = typename std::conditional::type; + +template +I16VecType MLAS_FORCEINLINE +PrepareZeroI16() +{ + if constexpr (IsQuantAUnsigned) { + return vdupq_n_u16(0); + } else { + return vdupq_n_s16(0); + } +} + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +) +{ + // First use i8 to quantize A. range [-128, 127] + // If convert to u8, +128. 
Range [0, 255] + assert(BlkLen % 16 == 0); + assert(BlkLen <= 256); + MLAS_DECLSPEC_ALIGN(static const uint8_t MASK[16], 16) = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + const int16x8_t v128 = vdupq_n_s16(128); + QuantAType* blob = reinterpret_cast*>(QuantA); + float* scale_ptr = QuantAScale; + size_t k = 0; + for (; k + BlkLen <= CountK; k += BlkLen) { + float32x4_t absMax0 = vdupq_n_f32(0.0f); + float32x4_t absMax1 = vdupq_n_f32(0.0f); + float32x4_t absMax2 = vdupq_n_f32(0.0f); + float32x4_t absMax3 = vdupq_n_f32(0.0f); + + for (size_t kk = 0; kk < BlkLen; kk += 16) { + const float32x4x4_t v0 = vld4q_f32(A + k + kk); + absMax0 = vmaxq_f32(absMax0, vabsq_f32(v0.val[0])); + absMax1 = vmaxq_f32(absMax1, vabsq_f32(v0.val[1])); + absMax2 = vmaxq_f32(absMax2, vabsq_f32(v0.val[2])); + absMax3 = vmaxq_f32(absMax3, vabsq_f32(v0.val[3])); + } + + const float32x4_t max01 = vmaxq_f32(absMax0, absMax1); + const float32x4_t max23 = vmaxq_f32(absMax2, absMax3); + const float32x4_t max0123 = vmaxq_f32(max01, max23); + const float maxScalar = vmaxvq_f32(max0123); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *scale_ptr = scale; + scale_ptr++; + + const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f; + const float32x4_t mul = vdupq_n_f32(inverse_scale); + + I16VecType sum_8_i16_0 = PrepareZeroI16(); + I16VecType sum_8_i16_1 = PrepareZeroI16(); + + for (size_t kk = 0; kk < BlkLen; kk += 16) { + const float32x4_t vfp32_0 = LoadFloat32x4(A + k + kk, 4); + const float32x4_t vfp32_1 = LoadFloat32x4(A + k + kk + 4, 4); + const float32x4_t vfp32_2 = LoadFloat32x4(A + k + kk + 8, 4); + const float32x4_t vfp32_3 = LoadFloat32x4(A + k + kk + 12, 4); + + const float32x4_t v0 = vmulq_f32(vfp32_0, mul); + const float32x4_t v1 = vmulq_f32(vfp32_1, mul); + const float32x4_t v2 = vmulq_f32(vfp32_2, mul); + const float32x4_t v3 = vmulq_f32(vfp32_3, mul); + + const int32x4_t i0 = vcvtnq_s32_f32(v0); + const int32x4_t i1 = vcvtnq_s32_f32(v1); + const int32x4_t i2 = vcvtnq_s32_f32(v2); + const int32x4_t i3 = vcvtnq_s32_f32(v3); + + const int16x8_t v_8_i16_0 = vcombine_s16(vqmovn_s32(i0), vqmovn_s32(i1)); + const int16x8_t v_8_i16_1 = vcombine_s16(vqmovn_s32(i2), vqmovn_s32(i3)); + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t v_8_u16_0 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16_0, v128)); + const uint16x8_t v_8_u16_1 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16_1, v128)); + const uint8x16_t v_16_u8 = vcombine_u8(vqmovn_u16(v_8_u16_0), vqmovn_u16(v_8_u16_1)); + vst1q_u8(blob + k + kk, v_16_u8); + + // accumulate Sum(a_i) + const uint16x8_t i_8_u16_0 = vmovl_u8(vget_low_u8(v_16_u8)); + const uint16x8_t i_8_u16_1 = vmovl_high_u8(v_16_u8); + sum_8_i16_0 = vaddq_u16(sum_8_i16_0, i_8_u16_0); + sum_8_i16_1 = vaddq_u16(sum_8_i16_1, i_8_u16_1); + } else { + const int8x16_t v_16_i8 = vcombine_s8(vqmovn_s16(v_8_i16_0), vqmovn_s16(v_8_i16_1)); + vst1q_s8(blob + k + kk, v_16_i8); + + // accumulate Sum(a_i) + sum_8_i16_0 = vaddq_s16(sum_8_i16_0, v_8_i16_0); + sum_8_i16_1 = vaddq_s16(sum_8_i16_1, v_8_i16_1); + } + } + + float qsum; + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t sum_8_u16 = vaddq_u16(sum_8_i16_0, sum_8_i16_1); + qsum = static_cast(vaddvq_u16(sum_8_u16)); + } else { + const int16x8_t sum_8_i16 = vaddq_s16(sum_8_i16_0, sum_8_i16_1); + qsum = static_cast(vaddvq_s16(sum_8_i16)); + } + + *AScaledBlkSum = scale * qsum; + AScaledBlkSum++; + } + + if (k < CountK) { + float32x4_t absMax = 
vdupq_n_f32(0.0f); + + for (size_t kk = k; kk < CountK; kk += 4) { + size_t step = std::min(static_cast(4), CountK - kk); + const float32x4_t v0 = LoadFloat32x4(A + kk, step); + absMax = vmaxq_f32(absMax, vabsq_f32(v0)); + } + + const float maxScalar = vmaxvq_f32(absMax); + const float scale = maxScalar / 127.f; + *scale_ptr = scale; + + const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f; + const float32x4_t mul = vdupq_n_f32(inverse_scale); + + I16VecType sum_8_i16 = PrepareZeroI16(); + + for (size_t kk = k; kk < CountK; kk += 4) { + size_t step = std::min(static_cast(4), CountK - kk); + const float32x4_t vfp32 = LoadFloat32x4(A + kk, step); + const float32x4_t v_f32 = vmulq_f32(vfp32, mul); + const int32x4_t v_i32 = vcvtnq_s32_f32(v_f32); + const int16x8_t v_8_i16 = vcombine_s16(vqmovn_s32(v_i32), vdup_n_s16(0)); + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t v_8_u16 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16, v128)); + uint8x8_t v_8_u8 = vqmovn_u16(v_8_u16); + vst1_lane_s32(reinterpret_cast(blob + kk), vreinterpret_s32_u8(v_8_u8), 0); + + // accumulate Sum(a_i) + v_8_u8 = vand_u8(v_8_u8, vld1_u8(MASK + 8 - step)); + const uint16x8_t i_8_u16 = vmovl_u8(v_8_u8); + sum_8_i16 = vaddq_u16(sum_8_i16, i_8_u16); + } else { + const int8x8_t v_8_i8 = vqmovn_s16(v_8_i16); + vst1_lane_s32(reinterpret_cast(blob + kk), vreinterpret_s32_s8(v_8_i8), 0); + + // accumulate Sum(a_i) + sum_8_i16 = vaddq_s16(sum_8_i16, v_8_i16); + } + } + + float qsum; + + if constexpr (IsQuantAUnsigned) { + qsum = static_cast(vaddvq_u16(sum_8_i16)); + } else { + qsum = static_cast(vaddvq_s16(sum_8_i16)); + } + + *AScaledBlkSum = scale * qsum; + + memset(blob + CountK, 0, BlkLen - (CountK % BlkLen)); + } +} + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + namespace { @@ -1399,4 +1663,764 @@ SQ4BitGemmKernel_CompInt8( return CountM; } +MLAS_FORCEINLINE void +Q8Int8GemmR2xC8DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountM % MRows2 == 0); + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float 
scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB03, scaleA1); + const float32x4_t scaleA1B47 = vmulq_n_f32(scaleB47, scaleA1); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc0_47 = vdupq_n_u32(0U); + uint32x4_t acc1_03 = vdupq_n_u32(0U); + uint32x4_t acc1_47 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_0_47, av1_16_i8, 0); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_1_47, av1_16_i8, 1); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_2_47, av1_16_i8, 2); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_3_47, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_u32(acc0_47)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_u32(acc1_03)); + accf1_47 = vfmaq_f32(accf1_47, scaleA1B47, vcvtq_f32_u32(acc1_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32_03); + accf1_47 = vaddq_f32(accf1_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + vst1q_f32(SumPtr + ldc, accf1_03); + vst1q_f32(SumPtr + ldc + 4, accf1_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC8DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc0_47 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_u32(acc0_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountM % MRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB, scaleA1); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc1_03 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_u32(acc1_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + ldc, accf1_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC1DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + assert(CountM % MRows2 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + float32x4_t accf1 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + const float scaleA1B = scaleB * scaleA1; + + uint32x4_t acc0 = vdupq_n_u32(0U); + uint32x4_t acc1 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vdotq_u32(acc0, bv_packed, av0_16_i8); + acc1 = vdotq_u32(acc1, bv_packed, av1_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_u32(acc0), scaleA0B); + accf1 = vfmaq_n_f32(accf1, vcvtq_f32_u32(acc1), scaleA1B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + float32_t accf1v = vaddvq_f32(accf1); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + accf1v += bias; + } + + *SumPtr = accf0v; + *(SumPtr + ldc) = accf1v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 
1 : 0; + ++SumPtr; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + + uint32x4_t acc0 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vdotq_u32(acc0, bv_packed, av0_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_u32(acc0), scaleA0B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + } + + *SumPtr = accf0v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +template <> +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +) { + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols8 = 8; + constexpr size_t NCols4 = 4; + constexpr size_t MRows2 = 2; + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t lda = BlockCountK * BlkLen; + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % MRows2; + size_t multipleRows = CountM - remainingRows; + size_t multipleCols8 = CountN & (~(NCols8 - 1)); + size_t multipleCols4 = CountN & (~(NCols4 - 1)); + size_t remainingCols4 = CountN % NCols4; + + if (multipleRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR2xC8DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols8, + BlockCountK, + Bias, + ldc + ); + } + + if (multipleRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR2xC4DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleCols8, + multipleRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? 
Bias + multipleCols8 : nullptr, + ldc + ); + } + + if (multipleRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR2xC1DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleCols4, + multipleRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR1xC8DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols8, + BlockCountK, + Bias, + ldc); + } + + if (remainingRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR1xC4DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols8, + remainingRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc); + } + + if (remainingRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR1xC1DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols4, + remainingRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc); + } + + return CountM; +} + +#ifdef USE_KLEIDIAI +void +SQ4BitGemmKernel_Packed_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* PackedQuantBData, + float* C, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN, + size_t CountK, + size_t ldc, + const float* Bias +) +{ + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel = + RangeCountM == 1 && RangeStartM == 0? GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); + + const size_t dst_stride = ldc * sizeof(float); + + const size_t lhs_packed_offset = ukernel.get_lhs_packed_offset(RangeStartM, CountK); + const size_t rhs_packed_offset = ukernel.get_rhs_packed_offset(RangeStartN, CountK, BlkLen); + const size_t dst_offset = ukernel.get_dst_offset(RangeStartM, RangeStartN, dst_stride); + + const void* lhs_ptr = QuantA + lhs_packed_offset; + const void* rhs_ptr = PackedQuantBData + rhs_packed_offset; + float* dst_ptr = reinterpret_cast(reinterpret_cast(C) + dst_offset); + + ukernel.run_matmul( + RangeCountM, RangeCountN, CountK, BlkLen, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max()); + + if (Bias != nullptr) { + for (size_t m = RangeStartM; m < RangeStartM + RangeCountM; m++) { + for (size_t n = RangeStartN; n < RangeStartN + RangeCountN; n++) { + C[m * ldc + n] += Bias[n]; + } + } + } +} +#endif + } // namespace sqnbitgemm_neon diff --git a/src/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp b/src/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp new file mode 100644 index 0000000..db040db --- /dev/null +++ b/src/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp @@ -0,0 +1,743 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. 
+ +Module Name: + + sqnbitgemm_kernel_neon_int8_i8mm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + input type T1 as float32 and + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8 + using i8mm instructions. + +--*/ + +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC8I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB03, scaleA1); + const float32x4_t scaleA1B47 = vmulq_n_f32(scaleB47, scaleA1); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc0_47 = vdupq_n_s32(0); + int32x4_t acc1_03 = vdupq_n_s32(0); + int32x4_t acc1_47 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + acc1_03 = 
vusdotq_laneq_s32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_0_47, av1_16_i8, 0); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_1_47, av1_16_i8, 1); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_2_47, av1_16_i8, 2); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_3_47, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_s32(acc0_47)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_s32(acc1_03)); + accf1_47 = vfmaq_f32(accf1_47, scaleA1B47, vcvtq_f32_s32(acc1_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32_03); + accf1_47 = vaddq_f32(accf1_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + vst1q_f32(SumPtr + ldc, accf1_03); + vst1q_f32(SumPtr + ldc + 4, accf1_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC8I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc0_47 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + 
uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_s32(acc0_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB, scaleA1); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc1_03 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = 
vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_s32(acc1_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + ldc, accf1_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + + int32x4_t acc0_03 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + 
BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC1I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + assert(CountM % NRows2 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + float32x4_t accf1 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + const float scaleA1B = scaleB * scaleA1; + + int32x4_t acc0 = vdupq_n_s32(0); + int32x4_t acc1 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vusdotq_s32(acc0, bv_packed, av0_16_i8); + acc1 = vusdotq_s32(acc1, bv_packed, av1_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_s32(acc0), scaleA0B); + accf1 = vfmaq_n_f32(accf1, vcvtq_f32_s32(acc1), scaleA1B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + float32_t accf1v = vaddvq_f32(accf1); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + accf1v += bias; + } + + *SumPtr = accf0v; + *(SumPtr + ldc) = accf1v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 
1 : 0; + ++SumPtr; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + + int32x4_t acc0 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vusdotq_s32(acc0, bv_packed, av0_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_s32(acc0), scaleA0B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + } + + *SumPtr = accf0v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +template <> +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +) { + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols8 = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t lda = BlockCountK * BlkLen; + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t multipleCols8 = CountN & (~(NCols8 - 1)); + size_t multipleCols4 = CountN & (~(NCols4 - 1)); + size_t remainingCols4 = CountN % NCols4; + + if (multipleRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR2xC8I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols8, + BlockCountK, + Bias, + ldc + ); + } + + if (multipleRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR2xC4I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleCols8, + multipleRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? 
Bias + multipleCols8 : nullptr, + ldc + ); + } + + if (multipleRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR2xC1I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleCols4, + multipleRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR1xC8I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols8, + BlockCountK, + Bias, + ldc); + } + + if (remainingRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR1xC4I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols8, + remainingRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc); + } + + if (remainingRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR1xC1I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols4, + remainingRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc); + } + + return CountM; +} + +} // namespace sqnbitgemm_neon diff --git a/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 45c3963..941b884 100644 --- a/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template diff --git a/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h index e9c3812..ed78dfa 100644 --- a/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h +++ b/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/src/lib/sqnbitgemm_q8_block.h b/src/lib/sqnbitgemm_q8_block.h index 80af2f4..9dc217c 100644 --- a/src/lib/sqnbitgemm_q8_block.h +++ b/src/lib/sqnbitgemm_q8_block.h @@ -66,5 +66,5 @@ MLAS_FORCEINLINE constexpr size_t Q8BlkAlignment() { - return alignof(float); + return 16 * alignof(float); } diff --git a/src/lib/sve/elementwise_sve.cpp b/src/lib/sve/elementwise_sve.cpp new file mode 100644 index 0000000..fb2972a --- /dev/null +++ b/src/lib/sve/elementwise_sve.cpp @@ -0,0 +1,683 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + elementwise_sve.cpp + +Abstract: + + This module contains the implementation of SVE-based elementwise operations +--*/ + +#include "mlasi_sve.h" +#include + +// +// Bundles the constants for use by kernels written in assembly. 
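+// MlasSveErfConstants packs the two polynomial branches used by MlasSveErfKernel below: a
+// small-|x| polynomial evaluated in x^2 up to ErfSplitBoundary, and a large-|x| polynomial
+// whose negated result is exponentiated with the bundled Exp_* constants and subtracted
+// from one.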
+// + +MLAS_INTERNAL_DATA const struct { + float ErfUpperAbsRange; + float ErfSplitBoundary; + float ErfSMALL_P0; + float ErfSMALL_P1; + float ErfSMALL_P2; + float ErfSMALL_P3; + float ErfSMALL_P4; + float ErfSMALL_P5_Minus_One; + float ErfReserved0; + float ErfBIG_P0; + float ErfBIG_P1; + float ErfBIG_P2; + float ErfBIG_P3; + float ErfBIG_P4; + float ErfBIG_P5; + float ErfBIG_P6_Minus_One; + float ErfNegZero; + float ErfOne; + + float Exp_UpperRange; + float Exp_LowerRange; + float Exp_Log2Reciprocal; + float Exp_log2_hi; + float Exp_log2_lo; + float Exp_P0; + float Exp_P1; + float Exp_P2; + float Exp_P3; + float Exp_P4; + float Exp_P5; + float Exp_P6; + float Exp_C; + int32_t Exp_X7F; +} MlasSveErfConstants = { + 3.925f, + 0.921875f, + -5.99104969e-4f, + 4.99339588e-3f, + -2.67667342e-2f, + 1.12818025e-1f, + -3.76124859e-1f, + 1.28379151e-1f, + 0.0f, + 1.72948930e-5f, + -3.83208680e-4f, + 3.88393435e-3f, + -2.42545605e-2f, + 1.06777847e-1f, + 6.34846687e-1f, + 1.28717512e-1f, + -0.0f, + 1.0f, + + // Independent parameters to calculate Exp for Erff() + 88.3762626647950f, + -88.3762626647949f, + 1.44269504088896341f, + -6.93145752e-1f, + -1.42860677e-6f, + 1.38319808e-3f, + 8.37550033e-3f, + 4.16689515e-2f, + 1.66664466e-1f, + 4.99999851e-1f, + 1.00000000e+0f, + 1.00000000e+0f, + 1.25829120e+7f, + 127, +}; + +MLAS_INTERNAL_DATA const struct { + float LowerRange; + float UpperRange; + float alpha_9; + float alpha_7; + float alpha_5; + float alpha_3; + float alpha_1; + float beta_10; + float beta_8; + float beta_6; + float beta_4; + float beta_2; + float beta_0; + float one_half; +} MlasSveLogisticConstants = { + -18.0f, + 18.0f, + 4.37031012579801e-11f, + 1.15627324459942e-07f, + 6.08574864600143e-05f, + 8.51377133304701e-03f, + 2.48287947061529e-01f, + 6.10247389755681e-13f, + 5.76102136993427e-09f, + 6.29106785017040e-06f, + 1.70198817374094e-03f, + 1.16817656904453e-01f, + 9.93151921023180e-01f, + 0.5f, +}; + +MLAS_INTERNAL_DATA const struct { + float LowerRange; + float UpperRange; + float LowerRangeSumExp; + float UpperRangeSumExp; + float RoundingBias; + float Log2Reciprocal; + float Log2High; + float Log2Low; + float poly_0; + float poly_1; + float poly_2; + float poly_3; + float poly_4; + float poly_56; + int32_t MinimumExponent; + int32_t MaximumExponent; +} MlasSveExpConstants = { + -103.9720840454f, + 88.7762626647950f, + -88.3762626647949f, + 88.3762626647949f, + MLAS_ROUNDING_BIAS_MAGIC, + 1.44269504088896341f, + -6.93145752e-1f, + -1.42860677e-6f, + 0x1.694000p-10, + 0x1.125edcp-7, + 0x1.555b5ap-5, + 0x1.555450p-3, + 0x1.fffff6p-2, + 0x1.000000p+0, + int32_t(0xC1000000), + int32_t(0x3F800000), +}; + +MLAS_INTERNAL_DATA const float MlasSveMinimumF32Value = std::numeric_limits::lowest(); + +void +MLASCALL +MlasSveErfKernel( + const float* Input, + float* Output, + size_t N + ) +/*++ + +Routine Description: + + This routine implements the generic kernel for the error function. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + +Return Value: + + None. 
+ +--*/ +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t sve_veclen = svcntw(); + size_t stride = sve_veclen; + + while (N > 0) { + // If fewer that SVE vector length elements are remaining, adjust the predicate + if (N < sve_veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Value = MlasSveLoadFloat32(Pred, Input); + MLAS_SVFLOAT32 NegZero = MlasSveBroadcastFloat32(MlasSveErfConstants.ErfNegZero); + MLAS_SVFLOAT32 SignMask = MlasSveAndFloat32(Pred, Value, NegZero); + MLAS_SVFLOAT32 AbsValue = MlasSveAndNotFloat32(Pred, NegZero, Value); + AbsValue = MlasSveMinimumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfUpperAbsRange), AbsValue); + MLAS_SVFLOAT32 SquareValue = MlasSveMultiplyFloat32(Pred, AbsValue, AbsValue); + + MLAS_SVFLOAT32 r_small = MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P0); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P1)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P2)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P3)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P4)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P5_Minus_One)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, AbsValue, AbsValue); + MLAS_SVFLOAT32 split_mask = MlasSveGreaterThanFloat32(Pred, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSplitBoundary)); + r_small = MlasSveAndNotFloat32(Pred, split_mask, r_small); + + AbsValue = MlasSveAndFloat32(Pred, split_mask, AbsValue); + MLAS_SVFLOAT32 r_big = MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P0); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P1)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P2)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P3)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P4)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P5)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P6_Minus_One)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, AbsValue); + + r_big = MlasSveXorFloat32(Pred, r_big, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfNegZero)); + r_big = MlasSveMaximumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_LowerRange), r_big); + MLAS_SVFLOAT32 exp_c = MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_C); + MLAS_SVFLOAT32 r = MlasSveMultiplyAddFloat32(Pred, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_Log2Reciprocal), r_big, exp_c); + r = MlasSveSubtractFloat32(Pred, r, exp_c); + + MLAS_SVFLOAT32 fx = MlasSveMultiplyAddFloat32(Pred, r, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_log2_hi), r_big); + fx = MlasSveMultiplyAddFloat32(Pred, r, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_log2_lo), fx); + + MLAS_SVFLOAT32 y = MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P0); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P1)); + 
y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P2)); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P3)); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P4)); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P5)); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P6)); + + y = MlasSveMultiplyFloat32(Pred, y, MlasSvePowerOf2Float32(Pred, r)); + y = MlasSveSubtractFloat32(Pred, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfOne), y); + + y = MlasSveOrFloat32(Pred, r_small, y); + y = MlasSveOrFloat32(Pred, y, SignMask); + MlasSveStoreFloat32(Pred, Output, y); + + Input += stride; + Output += stride; + N -= stride; + } +} + +void +MLASCALL +MlasSveLogisticKernel( + const float* Input, + float* Output, + size_t N + ) +/*++ + +Routine Description: + + This routine implements the generic kernel for the logistic function. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + +Return Value: + + None. + +--*/ +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t sve_veclen = svcntw(); + size_t stride = sve_veclen; + + while (N > 0) { + // If fewer that SVE vector length elements are remaining, adjust the predicate + if (N < sve_veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Value = MlasSveLoadFloat32(Pred, Input); + + Value = MlasSveMaximumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveLogisticConstants.LowerRange), Value); + Value = MlasSveMinimumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveLogisticConstants.UpperRange), Value); + + MLAS_SVFLOAT32 ValueSquared = MlasSveMultiplyFloat32(Pred, Value, Value); + + MLAS_SVFLOAT32 p; + p = MlasSveMultiplyAddFloat32( + Pred, + ValueSquared, + MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_9), + MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_7) + ); + p = MlasSveMultiplyAddFloat32(Pred, p, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_5)); + p = MlasSveMultiplyAddFloat32(Pred, p, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_3)); + p = MlasSveMultiplyAddFloat32(Pred, p, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_1)); + p = MlasSveMultiplyFloat32(Pred, p, Value); + + MLAS_SVFLOAT32 q; + q = MlasSveMultiplyAddFloat32( + Pred, + ValueSquared, + MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_10), + MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_8) + ); + q = MlasSveMultiplyAddFloat32(Pred, q, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_6)); + q = MlasSveMultiplyAddFloat32(Pred, q, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_4)); + q = MlasSveMultiplyAddFloat32(Pred, q, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_2)); + q = MlasSveMultiplyAddFloat32(Pred, q, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_0)); + + MlasSveStoreFloat32( + Pred, + Output, + MlasSveClampFloat32(Pred,MlasSveAddFloat32( + Pred, + MlasSveDivideFloat32(Pred, p, q), + MlasSveBroadcastFloat32(0.5f) + ), 0.0f, 1.0f) + ); + + Input += stride; + Output += stride; + N -= stride; + } +} + +/* +SVE implementation of expf() using a polynomial approximation is taken from ARM Compute Library Repository. 
+https://github.com/ARM-software/ComputeLibrary/blob/9f7a1fb06bc0435d989a9a6a3c0fd2cebfedbf5f/src/core/NEON/SVEMath.inl#L105 +*/ +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveComputeExpVector( + MLAS_SVBOOL Pred, + MLAS_SVFLOAT32 Vector +) +{ + const uint32_t svexp_f32_coeff[] = { + 0x3f7ffff6, // x^1: 0x1.ffffecp-1f + 0x3efffedb, // x^2: 0x1.fffdb6p-2f + 0x3e2aaf33, // x^3: 0x1.555e66p-3f + 0x3d2b9f17, // x^4: 0x1.573e2ep-5f + 0x3c072010, // x^5: 0x1.0e4020p-7f + }; + + const auto c1 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[0])); + const auto c2 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[1])); + const auto c3 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[2])); + const auto c4 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[3])); + const auto c5 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[4])); + + const auto shift = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = + MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = + MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + + const auto inf = MlasSveBroadcastFloat32(std::numeric_limits::infinity()); + const auto max_input = MlasSveBroadcastFloat32(88.37f); // Approximately ln(2^127.5) + const auto zero = MlasSveZeroFloat32(); + const auto min_input = MlasSveBroadcastFloat32(-86.64f); // Approximately ln(2^-125) + + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 forces decimal part + // of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will occupy + // the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2) + // (i.e. n) because the decimal part has been pushed out and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = MlasSveMultiplyAddFloat32(Pred, Vector, inv_ln2, shift); + const auto n = MlasSveSubtractFloat32(Pred, z, shift); + const auto scale = MlasSveReinterpretAsFLOAT32(MlasSveShiftLeftUInt32<23>(Pred, MlasSveReinterpretAsUInt32(z))); // 2^n + + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32. + // This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance. + const auto r_hi = MlasSveMultiplyAddFloat32(Pred, n, neg_ln2_hi, Vector); + const auto r = MlasSveMultiplyAddFloat32(Pred, n, neg_ln2_lo, r_hi); + + // Compute the truncated Taylor series of e^r. 
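+    // The evaluation below is Estrin-style: (c2 + c3*r) and (c4 + c5*r) are combined
+    // through r^2 before the final multiply-adds, which shortens the FMA dependency chain.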
+ // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = MlasSveMultiplyFloat32(Pred, r, r); + + const auto p1 = MlasSveMultiplyFloat32(Pred, c1, r); + const auto p23 = MlasSveMultiplyAddFloat32(Pred, c3, r, c2); + const auto p45 = MlasSveMultiplyAddFloat32(Pred, c5, r, c4); + const auto p2345 = MlasSveMultiplyAddFloat32(Pred, p45, r2, p23); + const auto p12345 = MlasSveMultiplyAddFloat32(Pred, p2345, r2, p1); + + auto poly = MlasSveMultiplyAddFloat32(Pred, p12345, scale, scale); + + // Handle underflow and overflow. + poly = MlasSveSelect(MlasSveCompareLessThan(Pred, Vector, min_input), zero, poly); + poly = MlasSveSelect(MlasSveCompareGreaterThan(Pred, Vector, max_input), inf, poly); + + return poly; +} + +void +MLASCALL +MlasSveComputeExpF32Kernel( + const float* Input, + float* Output, + size_t N +) +{ + const size_t veclen = svcntw(); + + // Fast path: Use scalar loop when N is 1 + if (N == 1) { + Output[0] = expf(Input[0]); + return; + } + + // Vectorized path + MLAS_SVBOOL Pred = svptrue_b32(); + size_t stride = veclen; + + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + + MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input); + Vector = MlasSveComputeExpVector(Pred, Vector); + MlasSveStoreFloat32(Pred, Output, Vector); + + Input += stride; + Output += stride; + N -= stride; + } +} + +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveComputeSumExpVector( + MLAS_SVBOOL Pred, + MLAS_SVFLOAT32 Vector, + MLAS_SVFLOAT32 NegativeMaximumVector +) +{ + Vector = MlasSveAddFloat32(Pred, Vector, NegativeMaximumVector); + Vector = MlasSveMaximumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveExpConstants.LowerRangeSumExp), Vector); + + const auto RoundingBias = MlasSveBroadcastFloat32(MlasSveExpConstants.RoundingBias); + auto biased = MlasSveMultiplyAddFloat32(Pred, Vector, MlasSveExpConstants.Log2Reciprocal, RoundingBias); + auto m = MlasSveSubtractFloat32(Pred, biased, RoundingBias); + + Vector = MlasSveMultiplyAddFloat32(Pred, m, MlasSveExpConstants.Log2High, Vector); + Vector = MlasSveMultiplyAddFloat32(Pred, m, MlasSveExpConstants.Log2Low, Vector); + + auto normal = MlasSveShiftLeftInt32<23>(Pred, MlasSveReinterpretAsInt32(biased)); + normal = MlasSveAddInt32(Pred, normal, MlasSveBroadcastInt32(MlasSveExpConstants.MaximumExponent)); + + auto p = MlasSveBroadcastFloat32(MlasSveExpConstants.poly_0); + p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_1); + p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_2); + p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_3); + p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_4); + p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_56); // <--| + p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_56); // Twice? + + p = MlasSveMultiplyFloat32(Pred, p, MlasSveReinterpretAsFloat32(normal)); + return p; +} + +float +MLASCALL +MlasSveComputeSumExpF32Kernel( + const float* Input, + float* Output, + size_t N, + const float* NegativeMaximum +) +/** + * Potential optimization: Consider applying loop unrolling to improve instruction-level + * parallelism (ILP) in this kernel. Evaluate the performance impact using benchmarks + * before and after implementing the optimization. 
+ */ +{ + if (N == 1) { + float result = expf(Input[0] + *NegativeMaximum); + if (Output != nullptr) { + Output[0] = result; + } + return result; + } + + MLAS_SVBOOL Pred = svptrue_b32(); + size_t veclen = svcntw(); + size_t stride = veclen; + float sum = 0.0f; + + MLAS_SVFLOAT32 NegativeMaximumVector = MlasSveBroadcastFloat32(*NegativeMaximum); + + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + + MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input); + Vector = MlasSveComputeSumExpVector(Pred, Vector, NegativeMaximumVector); + + if (Output != nullptr) { + MlasSveStoreFloat32(Pred, Output, Vector); + Output += stride; + } + + sum += MlasSveReduceAddFloat32(Pred, Vector); + + Input += stride; + N -= stride; + } + return sum; +} + +float MLASCALL +MlasSveReduceMaximumF32Kernel( + const float* Input, + size_t N +) +{ + size_t veclen = svcntw(); + MLAS_SVBOOL Pred = svptrue_b32(); + + float Maximum; + MLAS_SVFLOAT32 MaximumVector0 = MlasSveBroadcastFloat32(MlasSveMinimumF32Value); + + if (N >= veclen * 4) { + MLAS_SVFLOAT32 MaximumVector1 = MaximumVector0; + MLAS_SVFLOAT32 MaximumVector2 = MaximumVector0; + MLAS_SVFLOAT32 MaximumVector3 = MaximumVector0; + + while (N >= veclen * 4) { + MaximumVector0 = MlasSveMaximumFloat32(Pred, MaximumVector0, MlasSveLoadFloat32(Pred, Input)); + MaximumVector1 = MlasSveMaximumFloat32(Pred, MaximumVector1, MlasSveLoadFloat32(Pred, Input + veclen)); + MaximumVector2 = MlasSveMaximumFloat32(Pred, MaximumVector2, MlasSveLoadFloat32(Pred, Input + 2 * veclen)); + MaximumVector3 = MlasSveMaximumFloat32(Pred, MaximumVector3, MlasSveLoadFloat32(Pred, Input + 3 * veclen)); + + Input += veclen * 4; + N -= veclen * 4; + } + + MaximumVector0 = MlasSveMaximumFloat32(Pred, MaximumVector0, MaximumVector1); + MaximumVector2 = MlasSveMaximumFloat32(Pred, MaximumVector2, MaximumVector3); + MaximumVector0 = MlasSveMaximumFloat32(Pred, MaximumVector0, MaximumVector2); + } + size_t stride = veclen; + + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input); + MaximumVector0 = MlasSveMaximumFloat32(Pred, MaximumVector0, Vector); + + Input += stride; + N -= stride; + } + + Maximum = MlasSveReduceMaximumFloat32(svptrue_b32(), MaximumVector0); + return Maximum; +} + +void +MLASCALL +MlasSveReduceMinimumMaximumF32Kernel( + const float* Input, + float* Min, + float* Max, + size_t N +) +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t veclen = svcntw(); + size_t stride = veclen; + + float tmp_min = std::numeric_limits::max(); + float tmp_max = std::numeric_limits::lowest(); + + MLAS_SVFLOAT32 MaximumVector = MlasSveBroadcastFloat32(tmp_max); + MLAS_SVFLOAT32 MinimumVector = MlasSveBroadcastFloat32(tmp_min); + + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input); + MaximumVector = MlasSveMaximumFloat32(Pred, MaximumVector, Vector); + MinimumVector = MlasSveMinimumFloat32(Pred, MinimumVector, Vector); + + Input += stride; + N -= stride; + } + *Min = MlasSveReduceMinimumFloat32(svptrue_b32(), MinimumVector); + *Max = MlasSveReduceMaximumFloat32(svptrue_b32(), MaximumVector); +} + +void +MLASCALL +MlasSveComputeSoftmaxOutputF32Kernel( + float* Output, + size_t N, + const float* Parameters +) +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t veclen = svcntw(); + size_t stride = veclen; + + const float Scale = Parameters[0]; + const MLAS_SVFLOAT32 
ScaleVector = MlasSveBroadcastFloat32(Scale); + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Vector = MlasSveMultiplyFloat32(Pred, ScaleVector, MlasSveLoadFloat32(Pred, Output)); + MlasSveStoreFloat32(Pred, Output, Vector); + + Output += stride; + N -= stride; + } +} + +void +MLASCALL +MlasSveComputeLogSoftmaxOutputF32Kernel( + const float* Input, + float* Output, + size_t N, + const float* Parameters +) +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t veclen = svcntw(); + size_t stride = veclen; + + const float NegativeMaximum = Parameters[0]; + const float Logarithm = Parameters[1]; + MLAS_SVFLOAT32 NegativeMaximumVector = MlasSveBroadcastFloat32(NegativeMaximum); + MLAS_SVFLOAT32 LogarithmVector = MlasSveBroadcastFloat32(Logarithm); + + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input); + Vector = MlasSveAddFloat32(Pred, Vector, NegativeMaximumVector); + Vector = MlasSveSubtractFloat32(Pred, Vector, LogarithmVector); + MlasSveStoreFloat32(Pred, Output, Vector); + + Input += stride; + Output += stride; + N -= stride; + } + +} diff --git a/src/lib/sve/mlasi_sve.h b/src/lib/sve/mlasi_sve.h new file mode 100644 index 0000000..67a4bf4 --- /dev/null +++ b/src/lib/sve/mlasi_sve.h @@ -0,0 +1,653 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + mlasi_sve.h + +Abstract: + + This module contains the procedure prototypes for the SVE intrinsics. + +--*/ + +#pragma once + +#include "../mlasi.h" +#include // SVE intrinsic header + +#ifndef __clang__ +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+sve") + +// Use Clang-specific per-function attribute +#ifdef __clang__ +#define MLAS_SVE_TARGET __attribute__((target("arch=armv8.2-a+sve"))) +#else +#define MLAS_SVE_TARGET +#endif + +typedef svfloat32_t MLAS_SVFLOAT32; +typedef svint32_t MLAS_SVINT32; +typedef svuint32_t MLAS_SVUINT32; +typedef svbool_t MLAS_SVBOOL; + +// function decarations +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveComputeExpVector( + MLAS_SVBOOL Pred, + MLAS_SVFLOAT32 Vector +); + +void +MLASCALL +MlasSveComputeExpF32Kernel( + const float* Input, + float* Output, + size_t N +); + +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveComputeSumExpVector( + MLAS_SVBOOL Pred, + MLAS_SVFLOAT32 Vector, + MLAS_SVFLOAT32 NegativeMaximumVector +); + +float +MLASCALL +MlasSveComputeSumExpF32Kernel( + const float* Input, + float* Output, + size_t N, + const float* NegativeMaximum +); + +float MLASCALL +MlasSveReduceMaximumF32Kernel( + const float* Input, + size_t N +); + +void +MLASCALL +MlasSveReduceMinimumMaximumF32Kernel( + const float* Input, + float* Min, + float* Max, + size_t N +); + +void +MLASCALL +MlasSveComputeSoftmaxOutputF32Kernel( + float* Output, + size_t N, + const float* Parameters +); + +void +MLASCALL +MlasSveComputeLogSoftmaxOutputF32Kernel( + const float* Input, + float* Output, + size_t N, + const float* Parameters +); + +void +MLASCALL +MlasSveErfKernel( + const float* Input, + float* Output, + size_t N +); + +void +MLASCALL +MlasSveLogisticKernel( + const float* Input, + float* Output, + size_t N +); + +//MLAS API for SVE intrinsics + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveReinterpretAsInt32(MLAS_SVFLOAT32 Vector) +{ + return svreinterpret_s32_f32(Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT32 +MlasSveReinterpretAsUInt32(MLAS_SVFLOAT32 Vector) +{ + return svreinterpret_u32_f32(Vector); +} + +// Reinterprets an 
unsigned 32-bit vector as a 32-bit floating-point vector. +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveReinterpretAsFLOAT32(MLAS_SVUINT32 Vector) +{ + return svreinterpret_f32_u32(Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveCastToInt32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svcvt_s32_f32_z(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveCastToFloat32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector) +{ + return svcvt_f32_s32_z(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveBroadcastInt32(int32_t Value) +{ + return svdup_n_s32(Value); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveLoadInt32(MLAS_SVBOOL Pred, const int32_t* Buffer) +{ + return svld1_s32(Pred, Buffer); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreInt32(MLAS_SVBOOL Pred, int32_t* Buffer, MLAS_SVINT32 Vector) +{ + svst1_s32(Pred, Buffer, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveAddInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svadd_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveSubtractInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svsub_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveAndInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svand_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT32 +MlasSveAndUInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector1, MLAS_SVUINT32 Vector2) +{ + return svand_u32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveOrInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svorr_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveAndNotInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 VectorNot, MLAS_SVINT32 Vector) +{ + return svand_s32_m(Pred, svnot_s32_z(Pred, VectorNot), Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveXorInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return sveor_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveBlendInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2, MLAS_SVINT32 Selection) +{ + return MlasSveOrInt32( + Pred, + MlasSveAndInt32(Pred, Vector2, Selection), + MlasSveAndNotInt32(Pred, Selection, Vector1) + ); +} + +template +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT32 +MlasSveShiftLeftUInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector) +{ + return svlsl_n_u32_z(Pred, Vector, ShiftCount); +} + +template +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveShiftLeftInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector) +{ + return svlsl_n_s32_z(Pred, Vector, ShiftCount); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT32 +MlasSveShiftRightInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector, uint ShiftCount) +{ + return svlsr_n_u32_m(Pred, Vector, ShiftCount); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveMaximumInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svmax_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveMinimumInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svmin_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 
+MlasSveReinterpretAsFloat32(MLAS_SVINT32 Vector) +{ + return svreinterpret_f32_s32(Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveBroadcastFloat32(float Value) +{ + return svdup_n_f32(Value); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT32 +MlasSveBroadcastUINT32(uint Value) +{ + return svdup_n_u32(Value); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveBroadcastFloat32(const float* Value) +{ + return svld1_f32(svptrue_b32(), Value); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveZeroFloat32(void) +{ + return svdup_n_f32(0.0f); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveLoadFloat32(MLAS_SVBOOL Pred, const float* Buffer) +{ + return svld1_f32(Pred, Buffer); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreFloat32(MLAS_SVBOOL Pred, float* Buffer, MLAS_SVFLOAT32 Vector) +{ + svst1_f32(Pred, Buffer, Vector); +} + +template +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreLaneFloat32(float* Buffer, MLAS_SVFLOAT32 Vector) +{ + svbool_t Pred = svwhilelt_b32(Lane, Lane + 1); + svst1_f32(Pred, Buffer, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreLowHalfFloat32(float* Buffer, MLAS_SVFLOAT32 Vector) +{ + svbool_t Pred = svwhilelt_b32(0, (int32_t)svcntw() / 2); + svst1_f32(Pred, Buffer, Vector); +} + +template +MLAS_SVE_TARGET +MLAS_FORCEINLINE +float +MlasSveExtractLaneFloat32(MLAS_SVFLOAT32 Vector) +{ + float TmpBuffer[1]; + svbool_t Pred = svwhilelt_b32(Lane, Lane + 1); + svst1_f32(Pred, TmpBuffer, Vector); + return TmpBuffer[0]; +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveInterleaveLowFloat32(MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svzip1_f32(Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveInterleaveHighFloat32(MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svzip2_f32(Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svadd_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveSubtractFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svsub_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveMultiplyFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svmul_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveExpFloat32(MLAS_SVUINT32 Vector) +{ + return svexpa_f32(Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveScaleFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svscale_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveRoundINTFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svrintm_f32_z(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveMultiplyAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, MLAS_SVFLOAT32 Vector3) +{ + return svmla_f32_m(Pred, Vector3, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveMultiplyAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, float Scalar2, MLAS_SVFLOAT32 Vector3) +{ + return MlasSveMultiplyAddFloat32(Pred, Vector1, MlasSveBroadcastFloat32(Scalar2), Vector3); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 
+MlasSveMultiplyAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, float Scalar3) +{ + return MlasSveMultiplyAddFloat32(Pred, Vector1, Vector2, MlasSveBroadcastFloat32(Scalar3)); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveDivideFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svdiv_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveGreaterThanFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + // Compare Vector1 and Vector2, return a predicate vector + svbool_t cmp_mask = svcmpgt_f32(Pred, Vector1, Vector2); + + //Convert predicate to uint32_t mask + svuint32_t mask_bits = svdup_u32_z(cmp_mask, 0xFFFFFFFF); + + //Reinterpret to float32 + return svreinterpret_f32_u32(mask_bits); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveAndFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return MlasSveReinterpretAsFloat32( + MlasSveAndInt32( + Pred, + MlasSveReinterpretAsInt32(Vector1), + MlasSveReinterpretAsInt32(Vector2) + ) + ); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveOrFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return MlasSveReinterpretAsFloat32( + MlasSveOrInt32( + Pred, + MlasSveReinterpretAsInt32(Vector1), + MlasSveReinterpretAsInt32(Vector2) + ) + ); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveAndNotFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return MlasSveReinterpretAsFloat32( + MlasSveAndNotInt32( + Pred, + MlasSveReinterpretAsInt32(Vector1), + MlasSveReinterpretAsInt32(Vector2) + ) + ); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveXorFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return MlasSveReinterpretAsFloat32( + MlasSveXorInt32( + Pred, + MlasSveReinterpretAsInt32(Vector1), + MlasSveReinterpretAsInt32(Vector2) + ) + ); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveBlendFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, MLAS_SVFLOAT32 Selection) +{ + return MlasSveOrFloat32( + Pred, + MlasSveAndFloat32(Pred, Vector2, Selection), + MlasSveAndFloat32(Pred, Vector1, Selection) + ); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveMaximumFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svmax_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveMinimumFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svmin_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveClampFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Value, float LowerRange, float UpperRange) +{ + Value = MlasSveMaximumFloat32(Pred, MlasSveBroadcastFloat32(LowerRange), Value); + Value = MlasSveMinimumFloat32(Pred, MlasSveBroadcastFloat32(UpperRange), Value); + return Value; +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +float +MlasSveReduceAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svaddv_f32(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +float +MlasSveReduceMaximumFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svmaxv_f32(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +float +MlasSveReduceMinimumFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svminv_f32(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE 
+MLAS_SVFLOAT32 +MlasSvePowerOf2Float32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + MLAS_SVINT32 emm0 = MlasSveAddInt32( + Pred, + MlasSveCastToInt32(Pred, Vector), + MlasSveBroadcastInt32(127) + ); + return MlasSveReinterpretAsFloat32(MlasSveShiftLeftInt32<23>(Pred, emm0)); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveSelect(svbool_t Pred, MLAS_SVFLOAT32 TrueValue, MLAS_SVFLOAT32 FalseValue) +{ + return svsel_f32(Pred, TrueValue, FalseValue); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVBOOL +MlasSveCompareLessThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B) +{ + return svcmplt_f32(Pred, A, B); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVBOOL +MlasSveCompareGreaterThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B) +{ + return svcmpgt_f32(Pred, A, B); +} + +// GCC: Pop options after SVE-specific functions +#ifndef __clang__ +#pragma GCC pop_options +#endif + +#endif + diff --git a/src/lib/tanh.cpp b/src/lib/tanh.cpp index 9750337..63bd744 100644 --- a/src/lib/tanh.cpp +++ b/src/lib/tanh.cpp @@ -21,6 +21,7 @@ Module Name: --*/ #include "mlasi.h" +#include "softmax.h" // // Bundles the floating point constants for use by kernels written in assembly. @@ -119,9 +120,9 @@ Return Value: float Value = *Input++; - // This odd two-step process exists to ensure an input value of NaN carries through - // without modification because "std::min" and "std::max" return unreliable results - // when NaNs are involved, and it's clear from the test's reference outputs that + // This odd two-step process exists to ensure an input value of NaN carries through + // without modification because "std::min" and "std::max" return unreliable results + // when NaNs are involved, and it's clear from the test's reference outputs that // they want a NaN on output whenever the input is a NaN. float v_tmp; v_tmp = (Value < MlasTanhConstants.LowerRange) ? MlasTanhConstants.LowerRange : Value; @@ -149,9 +150,10 @@ Return Value: } } +template <> void MLASCALL -MlasComputeTanh( +MlasComputeTanh( const float* Input, float* Output, size_t N @@ -182,3 +184,50 @@ Return Value: MlasTanhKernel(Input, Output, N); #endif } + +template <> +void +MLASCALL +MlasComputeTanh( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +) { + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + if (dispatch == nullptr || dispatch->Tanh_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Tanh_Fp16 is not supported."); + } + dispatch->Tanh_Fp16(Input, Output, N); +} + +template <> +void +MLASCALL +MlasComputeSoftcap( + const float* Input, + float* Output, + size_t N, + float cap +) { + for (size_t i = 0; i < N; i++) { + Output[i] = Input[i] / cap; + Output[i] = std::tanh(Output[i]); + Output[i] = Output[i] * cap; + } +} + +template <> +void +MLASCALL +MlasComputeSoftcap( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N, + MLAS_FP16 cap +) { + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + if (dispatch == nullptr || dispatch->Softcap_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Softcap_Fp16 is not supported."); + } + dispatch->Softcap_Fp16(Input, Output, N, cap); +} diff --git a/src/lib/transpose.cpp b/src/lib/transpose.cpp index a758a0e..0c471e8 100644 --- a/src/lib/transpose.cpp +++ b/src/lib/transpose.cpp @@ -16,6 +16,20 @@ Module Name: #include "mlasi.h" +// +// Define the parameters to execute segments of a transpose operation on worker +// threads. 
+// + +template +struct MLAS_TRANPOSE_WORK_BLOCK { + ptrdiff_t ThreadCountM; + const ElementType* Input; + ElementType* Output; + size_t M; + size_t N; +}; + #if defined(MLAS_SSE2_INTRINSICS) MLAS_FORCEINLINE @@ -371,6 +385,180 @@ MlasTranspose16x16Block( vec_vsx_st(e0, 0, &Output[OutputStride * 14]); vec_vsx_st(e1, 0, &Output[OutputStride * 15]); } +#elif defined(MLAS_TARGET_S390X) + +MLAS_FORCEINLINE +void +MlasTranspose4x4Block( + const uint32_t* Input, + size_t InputStride, + uint32_t* Output, + size_t OutputStride + ) +{ + const __vector unsigned char mask0 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }; + const __vector unsigned char mask3 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }; + + __vector unsigned int a0 = vec_xl(0, Input); + __vector unsigned int a1 = vec_xl(0, &Input[InputStride]); + __vector unsigned int a2 = vec_xl(0, &Input[InputStride * 2]); + __vector unsigned int a3 = vec_xl(0, &Input[InputStride * 3]); + + __vector unsigned int b0 = vec_mergeh(a0, a1); + __vector unsigned int b1 = vec_mergeh(a2, a3); + __vector unsigned int b2 = vec_mergel(a0, a1); + __vector unsigned int b3 = vec_mergel(a2, a3); + + __vector unsigned int c0 = vec_perm(b0, b1, mask0); + __vector unsigned int c1 = vec_perm(b0, b1, mask3); + __vector unsigned int c2 = vec_perm(b2, b3, mask0); + __vector unsigned int c3 = vec_perm(b2, b3, mask3); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(c0); + MLAS_UNREFERENCED_PARAMETER(c1); + MLAS_UNREFERENCED_PARAMETER(c2); + MLAS_UNREFERENCED_PARAMETER(c3); + + vec_xst(c0, 0, Output); + vec_xst(c1, 0, &Output[OutputStride]); + vec_xst(c2, 0, &Output[OutputStride * 2]); + vec_xst(c3, 0, &Output[OutputStride * 3]); +} + +MLAS_FORCEINLINE +void +MlasTranspose16x16Block( + const uint8_t* Input, + size_t InputStride, + uint8_t* Output, + size_t OutputStride + ) +{ + const __vector unsigned char mask0 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }; + const __vector unsigned char mask3 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }; + + __vector unsigned char a0 = vec_xl(0, Input); + __vector unsigned char a1 = vec_xl(0, &Input[InputStride]); + __vector unsigned char a2 = vec_xl(0, &Input[InputStride * 2]); + __vector unsigned char a3 = vec_xl(0, &Input[InputStride * 3]); + __vector unsigned char a4 = vec_xl(0, &Input[InputStride * 4]); + __vector unsigned char a5 = vec_xl(0, &Input[InputStride * 5]); + __vector unsigned char a6 = vec_xl(0, &Input[InputStride * 6]); + __vector unsigned char a7 = vec_xl(0, &Input[InputStride * 7]); + __vector unsigned char a8 = vec_xl(0, &Input[InputStride * 8]); + __vector unsigned char a9 = vec_xl(0, &Input[InputStride * 9]); + __vector unsigned char a10 = vec_xl(0, &Input[InputStride * 10]); + __vector unsigned char a11 = vec_xl(0, &Input[InputStride * 11]); + __vector unsigned char a12 = vec_xl(0, &Input[InputStride * 12]); + __vector unsigned char a13 = vec_xl(0, &Input[InputStride * 13]); + __vector unsigned char a14 = vec_xl(0, &Input[InputStride * 14]); + __vector unsigned char a15 = vec_xl(0, &Input[InputStride * 15]); + + __vector unsigned char b0 = vec_mergeh(a0, a1); + __vector unsigned char b1 = vec_mergeh(a2, a3); + __vector unsigned char b2 = vec_mergeh(a4, a5); + __vector unsigned char b3 = vec_mergeh(a6, a7); + __vector unsigned char b4 = vec_mergeh(a8, a9); + __vector unsigned char b5 = vec_mergeh(a10, a11); + __vector unsigned char b6 = vec_mergeh(a12, a13); + __vector unsigned char b7 = 
vec_mergeh(a14, a15); + __vector unsigned char c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); + __vector unsigned char c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); + __vector unsigned char c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); + __vector unsigned char c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(c0); + MLAS_UNREFERENCED_PARAMETER(c1); + MLAS_UNREFERENCED_PARAMETER(c2); + MLAS_UNREFERENCED_PARAMETER(c3); + + __vector unsigned char d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + __vector unsigned char d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + __vector unsigned char e0 = vec_perm(d0, d1, mask0); + __vector unsigned char e1 = vec_perm(d0, d1, mask3); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(e0); + MLAS_UNREFERENCED_PARAMETER(e1); + + vec_xst(e0, 0, &Output[0]); + vec_xst(e1, 0, &Output[OutputStride]); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 2]); + vec_xst(e1, 0, &Output[OutputStride * 3]); + + c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); + c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); + c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); + c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 4]); + vec_xst(e1, 0, &Output[OutputStride * 5]); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 6]); 
+ vec_xst(e1, 0, &Output[OutputStride * 7]); + + b0 = vec_mergel(a0, a1); + b1 = vec_mergel(a2, a3); + b2 = vec_mergel(a4, a5); + b3 = vec_mergel(a6, a7); + b4 = vec_mergel(a8, a9); + b5 = vec_mergel(a10, a11); + b6 = vec_mergel(a12, a13); + b7 = vec_mergel(a14, a15); + + c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); + c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); + c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); + c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 8]); + vec_xst(e1, 0, &Output[OutputStride * 9]); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 10]); + vec_xst(e1, 0, &Output[OutputStride * 11]); + + c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); + c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); + c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); + c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 12]); + vec_xst(e1, 0, &Output[OutputStride * 13]); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 14]); + vec_xst(e1, 0, &Output[OutputStride * 15]); +} #elif defined(MLAS_LSX_INTRINSICS) @@ -470,20 +658,20 @@ MlasTranspose8x8Block( __m128i c3 = __lsx_vilvh_h(b3, b2); __m128 d0 = (__m128)(__lsx_vilvl_w(c2, c0)); - 
__lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(d0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(d0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + __lsx_vstelm_d(d0, &Output[OutputStride * 0], 0, 0); + __lsx_vstelm_d(d0, &Output[OutputStride * 1], 0, 1); __m128 d1 = (__m128)(__lsx_vilvh_w(c2, c0)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(d1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(d1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); + __lsx_vstelm_d(d1, &Output[OutputStride * 2], 0, 0); + __lsx_vstelm_d(d1, &Output[OutputStride * 3], 0, 1); __m128 d2 = (__m128)(__lsx_vilvl_w(c3, c1)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 4], 0), __lsx_vpickve2gr_d(d2, 0), 0), (__m128i *)&Output[OutputStride * 4], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 5], 0), __lsx_vpickve2gr_d(d2, 1), 0), (__m128i *)&Output[OutputStride * 5], 0); + __lsx_vstelm_d(d2, &Output[OutputStride * 4], 0, 0); + __lsx_vstelm_d(d2, &Output[OutputStride * 5], 0, 1); __m128 d3 = (__m128)(__lsx_vilvh_w(c3, c1)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 6], 0), __lsx_vpickve2gr_d(d3, 0), 0), (__m128i *)&Output[OutputStride * 6], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 7], 0), __lsx_vpickve2gr_d(d3, 1), 0), (__m128i *)&Output[OutputStride * 7], 0); + __lsx_vstelm_d(d3, &Output[OutputStride * 6], 0, 0); + __lsx_vstelm_d(d3, &Output[OutputStride * 7], 0, 1); } #endif @@ -509,7 +697,7 @@ MlasTranspose4xNVector( Output[OutputStride * 3] = a3; } -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) template MLAS_FORCEINLINE void @@ -541,54 +729,72 @@ MlasTranspose8xNVector( MlasTranspose4xNVector(&Input[InputStride * 4], InputStride, &Output[OutputStride * 4], OutputStride); } +template void -MLASCALL -MlasTranspose( - const uint32_t* Input, - uint32_t* Output, - size_t M, - size_t N - ) +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId +); /*++ Routine Description: - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). + This routine is invoked from a worker thread to execute a segment of a transpose Arguments: - Input - Supplies the input buffer. + Context - Supplies the pointer to the context for the threaded operation. - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. - - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + ThreadId - Supplies the current index of the threaded operation. Return Value: None. --*/ + +template<> +void +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId + ) { - size_t n = N; + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; + + // + // Partition the operation along the M dimension. + // + + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); + + // + // Set transpose parameters. 
+ // + + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; + + const uint32_t* Input = WorkBlock->Input + IndexM * N; + uint32_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 4 columns // at a time. // + size_t n = N; + while (n >= 4) { const uint32_t* s = Input; uint32_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) || \ - defined(MLAS_LSX_INTRINSICS) + defined(MLAS_TARGET_S390X) || defined(MLAS_LSX_INTRINSICS) while (m >= 4) { @@ -624,7 +830,7 @@ Return Value: const uint32_t* s = Input; uint32_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 4) { @@ -650,68 +856,45 @@ Return Value: } } +template<> void -MLASCALL -MlasTranspose( - const float* Input, - float* Output, - size_t M, - size_t N +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId ) { - MlasTranspose( - reinterpret_cast(Input), - reinterpret_cast(Output), - M, - N); -} - - -void -MLASCALL -MlasTranspose( - const uint16_t* Input, - uint16_t* Output, - size_t M, - size_t N - ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. + // + // Partition the operation along the M dimension. + // - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); -Return Value: + // + // Set transpose parameters. + // - None. + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; ---*/ -{ - size_t n = N; + const uint16_t* Input = WorkBlock->Input + IndexM * N; + uint16_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 4 columns // at a time. // + size_t n = N; + while (n >= 4) { const uint16_t* s = Input; uint16_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) @@ -749,7 +932,7 @@ Return Value: const uint16_t* s = Input; uint16_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 4) { @@ -775,52 +958,46 @@ Return Value: } } - +template<> void -MLASCALL -MlasTranspose( - const uint8_t* Input, - uint8_t* Output, - size_t M, - size_t N +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. +{ + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; - Output - Supplies the output buffer. + // + // Partition the operation along the M dimension. + // - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + // + // Set transpose parameters. 
+ // -Return Value: + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; - None. - ---*/ -{ - size_t n = N; + const uint8_t* Input = WorkBlock->Input + IndexM * N; + uint8_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 8 columns // at a time. // -#if defined(MLAS_TARGET_POWER) + + size_t n = N; + +#if defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) while (n >= 16) { const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 16) { MlasTranspose16x16Block(s, N, d, M); @@ -848,7 +1025,7 @@ Return Value: const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) @@ -886,7 +1063,7 @@ Return Value: const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 8) { @@ -912,17 +1089,140 @@ Return Value: } } +template void MLASCALL MlasTranspose( + const DataType* Input, + DataType* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ) +/*++ + +Routine Description: + + This routine transposes the input matrix (M rows by N columns) to the + output matrix (N rows by M columns). + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + M - Supplies the number of rows for the input matrix and the number of + columns for the output matrix. + + N - Supplies the number of columns for the input matrix and the number of + rows for the output matrix. + + ThreadPool - Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + +Return Value: + + None. + +--*/ +{ + MLAS_TRANPOSE_WORK_BLOCK WorkBlock; + + // + // Capture the transpose parameters to the work block. + // + + WorkBlock.Input = Input; + WorkBlock.Output = Output; + WorkBlock.M = M; + WorkBlock.N = N; + + // + // Compute the number of target threads given the complexity of the transpose + // operation. Limit the number of threads to the number of rows and try to + // keep each thread processing a minimum number of elements before using + // another thread. 
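// For illustration (a sketch, assuming MlasPartitionWork keeps its usual MLAS
// behavior of splitting the work count as evenly as possible): if
// MlasGetMaximumThreadCount returns 16 for a 10 x N transpose, ThreadCountM is
// clamped to 10 and every worker handles one row; with 4 workers and M = 10,
// MlasPartitionWork(ThreadId, 4, 10, &IndexM, &CountM) is expected to hand out
// row ranges of 3, 3, 2 and 2 rows, so each worker copies its CountM contiguous
// input rows starting at row IndexM into strided columns of the output.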
+ // + + ptrdiff_t ThreadCountM = MlasGetMaximumThreadCount(ThreadPool); + + if (size_t(ThreadCountM) > M) { + ThreadCountM = ptrdiff_t(M); + } + + WorkBlock.ThreadCountM = ThreadCountM; + + MlasExecuteThreaded(MlasTransposeThreaded, &WorkBlock, ThreadCountM, ThreadPool); +} + +template +void +MLASCALL +MlasTranspose( + const uint32_t* Input, + uint32_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template +void +MLASCALL +MlasTranspose( + const uint16_t* Input, + uint16_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template +void +MLASCALL +MlasTranspose( + const uint8_t* Input, + uint8_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template<> +void +MLASCALL +MlasTranspose( const int8_t* Input, int8_t* Output, size_t M, - size_t N) + size_t N, + MLAS_THREADPOOL* ThreadPool + ) { MlasTranspose( reinterpret_cast(Input), reinterpret_cast(Output), M, - N); + N, + ThreadPool); +} + +template<> +void +MLASCALL +MlasTranspose( + const float* Input, + float* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasTranspose( + reinterpret_cast(Input), + reinterpret_cast(Output), + M, + N, + ThreadPool); } diff --git a/src/lib/x86_64/ConvSymKernelAvx2.S b/src/lib/x86_64/ConvSymKernelAvx2.S index 3004599..194f210 100644 --- a/src/lib/x86_64/ConvSymKernelAvx2.S +++ b/src/lib/x86_64/ConvSymKernelAvx2.S @@ -23,6 +23,91 @@ Abstract: .intel_syntax noprefix + .extern CheckSaturationForVPMADDUBSW + + .macro CheckSaturation VecReg1Num, VecReg2Num + +// +// Save all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11) +// + + push rax + push rcx + push rdx + push rsi + push rdi + push r8 + push r9 + push r10 + push r11 + + sub rsp, 512 # reserve space for 16 YMM registers (32 bytes) + +// +// Save YMM registers (YMM0 to YMM15) +// + + vmovdqu [rsp], ymm0 + vmovdqu [rsp+32], ymm1 + vmovdqu [rsp+64], ymm2 + vmovdqu [rsp+96], ymm3 + vmovdqu [rsp+128], ymm4 + vmovdqu [rsp+160], ymm5 + vmovdqu [rsp+192], ymm6 + vmovdqu [rsp+224], ymm7 + vmovdqu [rsp+256], ymm8 + vmovdqu [rsp+288], ymm9 + vmovdqu [rsp+320], ymm10 + vmovdqu [rsp+352], ymm11 + vmovdqu [rsp+384], ymm12 + vmovdqu [rsp+416], ymm13 + vmovdqu [rsp+448], ymm14 + vmovdqu [rsp+480], ymm15 + + lea rdi, [rsp+32*\VecReg1Num\()] # first operand (unsigned) + lea rsi, [rsp+32*\VecReg2Num\()] # second operand (signed) + + call CheckSaturationForVPMADDUBSW + +// +// Restore YMM registers +// + + vmovdqu ymm0, [rsp] + vmovdqu ymm1, [rsp+32] + vmovdqu ymm2, [rsp+64] + vmovdqu ymm3, [rsp+96] + vmovdqu ymm4, [rsp+128] + vmovdqu ymm5, [rsp+160] + vmovdqu ymm6, [rsp+192] + vmovdqu ymm7, [rsp+224] + vmovdqu ymm8, [rsp+256] + vmovdqu ymm9, [rsp+288] + vmovdqu ymm10, [rsp+320] + vmovdqu ymm11, [rsp+352] + vmovdqu ymm12, [rsp+384] + vmovdqu ymm13, [rsp+416] + vmovdqu ymm14, [rsp+448] + vmovdqu ymm15, [rsp+480] + + add rsp, 512 # clean up the reserved stack space + +// +// Restore all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11) +// + + pop r11 + pop r10 + pop r9 + pop r8 + pop rdi + pop rsi + pop rdx + pop rcx + pop rax + + .endm + /*++ Macro Description: @@ -52,9 +137,15 @@ Implicit Arguments: .macro MultiplyAccumulateRowAvx2 Vec1Reg, Vec2Reg +#if defined(ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + CheckSaturation 2,0 +#endif vpmaddubsw ymm3,ymm2,ymm0 vpmaddwd ymm3,ymm3,ymm12 vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 +#if defined(ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + CheckSaturation 2,1 +#endif vpmaddubsw ymm2,ymm2,ymm1 vpmaddwd ymm2,ymm2,ymm12 
vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 diff --git a/src/ort_include/core/common/common.h b/src/ort_include/core/common/common.h index 0822eba..820d140 100644 --- a/src/ort_include/core/common/common.h +++ b/src/ort_include/core/common/common.h @@ -148,6 +148,26 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch abort(); \ } while (false) +#define ORT_THROW_FROM_STATUS(status) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException( \ + ORT_WHERE_WITH_STACK, status.ToString()) \ + .what()); \ + abort(); \ + } while (false) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) \ + .what()); \ + abort(); \ + } while (false) + #else #define ORT_TRY try @@ -180,6 +200,16 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch #define ORT_THROW_EX(ex, ...) \ throw ex(__VA_ARGS__) +#define ORT_THROW_FROM_STATUS(status) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, status.ToString(), status.Category(), \ + static_cast<::onnxruntime::common::StatusCode>(status.Code())) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) + #endif #define ORT_MAKE_STATUS(category, code, ...) \ @@ -237,7 +267,7 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch auto _status = (expr); \ if ((!_status.IsOK())) { \ ::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast(__FUNCTION__), __LINE__); \ - ORT_THROW(_status); \ + ORT_THROW_FROM_STATUS(_status); \ } \ } while (0) @@ -264,15 +294,29 @@ inline std::string ToUTF8String(const std::string& s) { return s; } /** * Convert a wide character string to a UTF-8 string */ -std::string ToUTF8String(const std::wstring& s); - -std::wstring ToWideString(const std::string& s); +std::string ToUTF8String(std::wstring_view s); +inline std::string ToUTF8String(const wchar_t* s) { + return ToUTF8String(std::wstring_view{s}); +} +inline std::string ToUTF8String(const std::wstring& s) { + return ToUTF8String(std::wstring_view{s}); +} +std::wstring ToWideString(std::string_view s); +inline std::wstring ToWideString(const char* s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::string& s) { + return ToWideString(std::string_view{s}); +} inline std::wstring ToWideString(const std::wstring& s) { return s; } +inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; } #else inline std::string ToWideString(const std::string& s) { return s; } +inline std::string ToWideString(const char* s) { return s; } +inline std::string ToWideString(std::string_view s) { return std::string{s}; } #endif -constexpr size_t kMaxStrLen = 2048; +constexpr size_t kMaxStrLen = 4096; // Returns whether `key` is in `container`. // Like C++20's map/set contains() member function. 
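A brief usage sketch of the widened string-conversion overload set added to common.h above. The function and path handling here are hypothetical, and the sketch assumes a build configuration in which the wide-character overloads are compiled in (the #else branch above suggests they are platform-conditional).

    #include <string>
    #include <string_view>

    // Assumes the overloads declared in core/common/common.h are visible,
    // i.e. this code lives in (or uses) the onnxruntime namespace.
    std::string DescribeModelPath(const wchar_t* raw_path) {
        // Picks ToUTF8String(std::wstring_view) via the const wchar_t* forwarder,
        // without first materializing a std::wstring.
        std::string utf8 = onnxruntime::ToUTF8String(raw_path);
        // Round-trips back to wide characters through the string_view overload.
        std::wstring wide = onnxruntime::ToWideString(std::string_view{utf8});
        return utf8 + " (" + std::to_string(wide.size()) + " wide chars)";
    }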
diff --git a/src/ort_include/core/common/const_pointer_container.h b/src/ort_include/core/common/const_pointer_container.h index 1d821ba..80343b4 100644 --- a/src/ort_include/core/common/const_pointer_container.h +++ b/src/ort_include/core/common/const_pointer_container.h @@ -79,6 +79,10 @@ class ConstPointerContainer { return data_[index]; } + const T* const* data() const { + return data_.data(); + } + private: const Container& data_; }; diff --git a/src/ort_include/core/common/cpuid_arch_definition.h b/src/ort_include/core/common/cpuid_arch_definition.h index a541eb6..5946b8c 100644 --- a/src/ort_include/core/common/cpuid_arch_definition.h +++ b/src/ort_include/core/common/cpuid_arch_definition.h @@ -9,6 +9,6 @@ #define CPUIDINFO_ARCH_X86 #endif -#if defined(_M_ARM64) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) +#if defined(_M_ARM64) || defined(_M_ARM64EC) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) #define CPUIDINFO_ARCH_ARM #endif // ARM or ARM64 diff --git a/src/ort_include/core/common/cpuid_info.h b/src/ort_include/core/common/cpuid_info.h index 4c9e7e8..9c40627 100644 --- a/src/ort_include/core/common/cpuid_info.h +++ b/src/ort_include/core/common/cpuid_info.h @@ -15,6 +15,14 @@ class CPUIDInfo { return cpuid_info; } + std::string_view GetCPUVendor() const { + return vendor_; + } + + uint32_t GetCPUVendorId() const { + return vendor_id_; + } + bool HasAMX_BF16() const { return has_amx_bf16_; } bool HasAVX() const { return has_avx_; } bool HasAVX2() const { return has_avx2_; } @@ -25,12 +33,16 @@ class CPUIDInfo { bool HasSSE3() const { return has_sse3_; } bool HasSSE4_1() const { return has_sse4_1_; } bool IsHybrid() const { return is_hybrid_; } + bool HasTPAUSE() const { return has_tpause_; } // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + bool HasArmSve() const { return has_arm_sve_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + bool HasArm_SME() const { return has_arm_sme_; } + bool HasArm_SME2() const { return has_arm_sme2_; } uint32_t GetCurrentCoreIdx() const; @@ -93,7 +105,40 @@ class CPUIDInfo { } private: + // Log function that uses ORT logging if available or writes to stderr. + // This enables us to log even before ORT logging has been initialized. 
+ static void LogEarlyWarning(std::string_view message); + CPUIDInfo(); + + void VendorInfoInit(); + +#if defined(CPUIDINFO_ARCH_X86) + + void X86Init(); + +#elif defined(CPUIDINFO_ARCH_ARM) + +#if defined(__linux__) + + void ArmLinuxInit(); + +#elif defined(_WIN32) + + void ArmWindowsInit(); + +#elif defined(__APPLE__) + + void ArmAppleInit(); + +#endif + +#endif // defined(CPUIDINFO_ARCH_ARM) + +#if defined(CPUINFO_SUPPORTED) + bool pytorch_cpuinfo_init_{false}; +#endif // defined(CPUINFO_SUPPORTED) + bool has_amx_bf16_{false}; bool has_avx_{false}; bool has_avx2_{false}; @@ -104,6 +149,7 @@ class CPUIDInfo { bool has_sse3_{false}; bool has_sse4_1_{false}; bool is_hybrid_{false}; + bool has_tpause_{false}; std::vector core_uarchs_; // micro-arch of each core @@ -115,35 +161,14 @@ class CPUIDInfo { bool has_arm_neon_dot_{false}; bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_{false}; bool has_arm_sve_i8mm_{false}; bool has_arm_neon_bf16_{false}; + bool has_arm_sme_{false}; + bool has_arm_sme2_{false}; -#if defined(CPUIDINFO_ARCH_X86) - - void X86Init(); - -#elif defined(CPUIDINFO_ARCH_ARM) - -#if defined(CPUINFO_SUPPORTED) - // Now the following var is only used in ARM build, but later on we may expand the usage. - bool pytorch_cpuinfo_init_{false}; -#endif // defined(CPUINFO_SUPPORTED) - -#if defined(__linux__) - - void ArmLinuxInit(); - -#elif defined(_WIN32) - - void ArmWindowsInit(); - -#elif defined(__APPLE__) - - void ArmAppleInit(); - -#endif - -#endif // defined(CPUIDINFO_ARCH_ARM) + std::string vendor_; + uint32_t vendor_id_; }; -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/src/ort_include/core/framework/endian.h b/src/ort_include/core/common/endian.h similarity index 100% rename from src/ort_include/core/framework/endian.h rename to src/ort_include/core/common/endian.h diff --git a/src/ort_include/core/common/exceptions.h b/src/ort_include/core/common/exceptions.h index 494a770..6d0f6ed 100644 --- a/src/ort_include/core/common/exceptions.h +++ b/src/ort_include/core/common/exceptions.h @@ -11,6 +11,7 @@ #include #include "core/common/common.h" +#include "core/common/status.h" #include "core/common/code_location.h" namespace onnxruntime { @@ -35,12 +36,44 @@ class OnnxRuntimeException : public std::exception { /** Create a new exception that captures the location it was thrown from. @param location Location in the source code the exception is being thrown from + @param msg Message containing additional information about the exception cause. + @param category Error category + @param code Error code + */ + + OnnxRuntimeException(const CodeLocation& location, + const std::string& message, + common::StatusCategory category, + common::StatusCode code) noexcept + : OnnxRuntimeException(location, nullptr, message, category, code) { + } + + /** + Create a new exception that captures the location it was thrown from. + The instance will be created with ONNXRUNTIME category and FAIL code. + @param location Location in the source code the exception is being thrown from @param failed_condition Optional string containing the condition that failed. e.g. "tensor.Size() == input.Size()". May be nullptr. @param msg Message containing additional information about the exception cause. 
*/ - OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) - : location_{location} { + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) noexcept + : OnnxRuntimeException(location, failed_condition, msg, + common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL) { + } + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. + @param category Error category + @param code Error code + */ + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg, + common::StatusCategory category, + common::StatusCode code) + : location_{location}, category_(category), code_(code) { std::ostringstream ss; ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous @@ -58,6 +91,14 @@ class OnnxRuntimeException : public std::exception { what_ = ss.str(); } + common::StatusCategory Category() const noexcept { + return category_; + } + + common::StatusCode Code() const noexcept { + return code_; + } + const char* what() const noexcept override { return what_.c_str(); } @@ -66,6 +107,8 @@ class OnnxRuntimeException : public std::exception { const CodeLocation location_; const std::vector stacktrace_; std::string what_; + common::StatusCategory category_; + common::StatusCode code_; }; } // namespace onnxruntime diff --git a/src/ort_include/core/framework/float16.h b/src/ort_include/core/common/float16.h similarity index 97% rename from src/ort_include/core/framework/float16.h rename to src/ort_include/core/common/float16.h index dac0a01..7cca9be 100644 --- a/src/ort_include/core/framework/float16.h +++ b/src/ort_include/core/common/float16.h @@ -4,9 +4,9 @@ #include -#include "endian.h" +#include "core/common/endian.h" #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 -#include "cuda_bf16.h" +#include "cuda_bf16.h" // from CUDA SDK #endif #if !defined(__CUDACC__) && !defined(__HIPCC__) @@ -261,19 +261,19 @@ struct BFloat16 : onnxruntime_float16::BFloat16Impl { // initializers with MLFloat16 and BFloat16 from unsigned short // E.g 10_f16 or 10_b16 #if !defined(__CUDACC__) && !defined(__HIPCC__) -inline MLFloat16 operator"" _f16(unsigned long long int v) noexcept { +inline MLFloat16 operator""_f16(unsigned long long int v) noexcept { return MLFloat16::FromBits(narrow(v)); } -inline MLFloat16 operator"" _fp16(long double v) noexcept { +inline MLFloat16 operator""_fp16(long double v) noexcept { return MLFloat16(static_cast(v)); } -inline BFloat16 operator"" _b16(unsigned long long int v) noexcept { +inline BFloat16 operator""_b16(unsigned long long int v) noexcept { return BFloat16::FromBits((narrow(v))); } -inline BFloat16 operator"" _bfp16(long double v) noexcept { +inline BFloat16 operator""_bfp16(long double v) noexcept { return BFloat16(static_cast(v)); } #endif diff --git a/src/ort_include/core/common/float8.h b/src/ort_include/core/common/float8.h new file mode 100644 index 0000000..7dde1d0 --- /dev/null +++ b/src/ort_include/core/common/float8.h @@ -0,0 +1,937 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
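// Orientation for the bit manipulation in this header (a summary of what the
// Float8E4M3FN conversion routines below implement): the byte is laid out as
// 1 sign bit, 4 exponent bits (bias 7) and 3 mantissa bits; there is no
// infinity, 0x7F/0xFF encode NaN, and the largest finite value is 0x7E (448).
// Worked example: 1.0f has bits 0x3F800000 (e = 127, m = 0), so the encoder
// takes the normalized branch with ex = 127 - 120 = 7 and produces
// val = (7 << 3) = 0x38; decoding 0x38 restores exponent 7 - 7 + 127 = 127,
// i.e. 1.0f.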
+ +// IMPORTANT NOTE: Users of this file MUST include "cuda.h" before including this header +// if they would like to leverage the CUDA implementation for the conversion routines +// in their HOST code (code compiled by MSVC/GCC). +// This is because there is a check on CUDA_VERSION which is a macro defined in cuda.h. +// We can't include cuda.h in this header unconditionally because this header is also +// included in core framework files which are CUDA-agnostic. +// Not including "cuda.h" in GCC/MSVC will fall-back to the CPU conversion routines +// implemented in this file. +// For code compiled by NVCC which includes this header, this file will automatically +// include cuda.h (based on the CUDA_CC macro). + +#pragma once + +#if !defined(DISABLE_FLOAT8_TYPES) + +#include "core/common/endian.h" + +#if defined(__CUDACC__) +// Needed for CUDA_VERSION check below +#include +#endif + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 +#include "cuda_fp8.h" +#endif + +#if !defined(__CUDACC__) && !defined(__HIPCC__) +#include "core/common/narrow.h" +#endif + +#include "core/common/common.h" + +namespace onnxruntime { + +#if defined(__CUDACC__) || defined(__HIPCC__) +#define ORT_HOST_DEVICE __host__ __device__ +#else +#define ORT_HOST_DEVICE +#endif + +// Float8E4M3FN +struct Float8E4M3FN { + uint8_t val{0}; +#if defined(__HIP__) + ORT_HOST_DEVICE Float8E4M3FN() = default; +#else + Float8E4M3FN() = default; +#endif + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float8E4M3FN(unsigned char bits, FromBitsT) : val(bits) {} + + inline explicit ORT_HOST_DEVICE Float8E4M3FN(float v, bool saturate = true) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + val = __nv_cvt_float_to_fp8(v, saturate ? 
__NV_SATFINITE : __NV_NOSAT, __NV_E4M3); +#else + uint32_t b; + std::memcpy(&b, &v, sizeof(b)); + + val = static_cast((b & 0x80000000) >> 24); // sign + if ((b & 0x7fffffff) == 0x7f800000) { // infinity + if (saturate) { + val |= 126; + } else { + val |= 0x7f; + } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; + } else { + uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent + uint32_t m = static_cast(b & 0x007FFFFF); // mantissa + if (e != 0) { + if (e < 117) { + } else if (e < 121) { + // denormalized number + auto d = 120 - e; + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (20 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 136) { + // normalized number + auto ex = e - 120; + if (ex == 0) { + val |= 0x4; + val |= m >> 21; + } else { + val |= ex << 3; + val |= m >> 20; + if ((val & 0x7F) == 0x7F) { + val &= 0xFE; + } + } + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { + if ((val & 0x7F) < 0x7E) { + // rounding + val += 1; + } else if (!saturate) { + val |= 0x7F; + } + } + } else if (saturate) { + val |= 126; // 0b01111110 + } else { + val |= 0x7F; + } + } + } +#endif + } + + inline ORT_HOST_DEVICE bool IsNaN() const { + return (val & 0b01111111) == 0b01111111; + } + + inline ORT_HOST_DEVICE float ToFloat() const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E4M3)); +#else + uint32_t res; + if (val == 255) { + res = 0xffc00000; + } else if (val == 127) { + res = 0x7fc00000; + } else { + uint32_t expo = (val & 0x78) >> 3; + uint32_t mant = val & 0x07; + uint32_t sign = val & 0x80; + res = sign << 24; + if (expo == 0) { + if (mant > 0) { + expo = 0x7F - 7; + if ((mant & 0x4) == 0) { + mant &= 0x3; + mant <<= 1; + expo -= 1; + } + if ((mant & 0x4) == 0) { + mant &= 0x3; + mant <<= 1; + expo -= 1; + } + res |= (mant & 0x3) << 21; + res |= expo << 23; + } + } else { + res |= mant << 20; + expo -= 0x7; + expo += 0x7F; + res |= expo << 23; + } + } + float float_res; + std::memcpy(&float_res, &res, sizeof(float)); + return float_res; +#endif + } + + inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { val = *reinterpret_cast(&value); } + explicit ORT_HOST_DEVICE operator __nv_fp8_e4m3() const { return *reinterpret_cast(&val); } +#endif +}; + +inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val == right.val; } +inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val < right.val; } + +// User defined suffixes to make it easier to declare +// initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char +#if !defined(__CUDACC__) && !defined(__HIPCC__) + +inline Float8E4M3FN operator""_f8e4m3fn(unsigned long long int v) { + return Float8E4M3FN(narrow(v), Float8E4M3FN::FromBits()); +} + +inline Float8E4M3FN operator""_f8e4m3fnp8(long double v) { + return Float8E4M3FN(static_cast(v), true); +} + +#endif + +inline void Float8E4M3FNToFloat(const Float8E4M3FN* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + 
for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToFloat8E4M3FN(const float* flt, Float8E4M3FN* blf, size_t size, bool saturate) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) Float8E4M3FN(*src, saturate); + } +} + +// Float8E4M3FNUZ +struct Float8E4M3FNUZ { + uint8_t val{0}; +#if defined(__HIP__) + ORT_HOST_DEVICE Float8E4M3FNUZ() = default; +#else + Float8E4M3FNUZ() = default; +#endif + + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float8E4M3FNUZ(unsigned char bits, FromBitsT) : val(bits) {} + + inline explicit ORT_HOST_DEVICE Float8E4M3FNUZ(float v, bool saturate = true) { + // This type does not exist on CUDA. + uint32_t b; + std::memcpy(&b, &v, sizeof(b)); + + val = static_cast((b & 0x80000000) >> 24); // sign + if ((b & 0x7fffffff) == 0x7f800000) { // infinity + if (saturate) { + // the highest available value + val |= 0x7F; + } else { + // NaN + val = 0x80; + } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; + } else { + uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent + uint32_t m = static_cast(b & 0x007FFFFF); // mantissa + + if (e < 116) { + // all near-zero numbers round to positive zero: + val = 0; + } else if (e < 120) { + // denormalized number + auto d = 119 - e; + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } else { + // round to positive zero: + val = 0; + } + auto mask = 1 << (20 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 135) { + // normalized number + auto ex = e - 119; + if (ex == 0) { + val |= 0x4; + val |= m >> 21; + } else { + val |= ex << 3; + val |= m >> 20; + } + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { + if ((val & 0x7F) < 0x7F) { + // rounding + val += 1; + } else if (!saturate) { + val = 0x80; + } + } + } else if (saturate) { + val |= 0x7F; + } else { + val = 0x80; + } + } + } + + inline ORT_HOST_DEVICE bool IsNaN() const { + return val == 0b10000000; + } + + inline ORT_HOST_DEVICE float ToFloat() const { + // This type does not exist on CUDA. 
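// Decode layout, as a reading aid for the bit manipulation below: E4M3FNUZ
// keeps 1 sign bit, 4 exponent bits and 3 mantissa bits like E4M3FN, but uses
// exponent bias 8, has no infinities or negative zero, and reserves the single
// encoding 0x80 for NaN. For example, 0x40 (exponent field 8, mantissa 0)
// decodes to 8 - 8 + 127 = 127 in the float exponent field, i.e. 1.0f.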
+ uint32_t res; + if (val == 0x80) { + res = 0xffc00000; + } else { + uint32_t expo = (val & 0x78) >> 3; + uint32_t mant = val & 0x07; + uint32_t sign = val & 0x80; + res = sign << 24; + if (expo == 0) { + if (mant > 0) { + expo = 0x7F - 8; + if ((mant & 0x4) == 0) { + mant &= 0x3; + mant <<= 1; + expo -= 1; + } + if ((mant & 0x4) == 0) { + mant &= 0x3; + mant <<= 1; + expo -= 1; + } + res |= (mant & 0x3) << 21; + res |= expo << 23; + } + } else { + res |= mant << 20; + expo -= 8; + expo += 0x7F; + res |= expo << 23; + } + } + float float_res; + std::memcpy(&float_res, &res, sizeof(float)); + return float_res; + } + + inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } +}; + +inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val == right.val; } +inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val < right.val; } + +// User defined suffixes to make it easier to declare +// initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char +#if !defined(__CUDACC__) && !defined(__HIPCC__) + +inline Float8E4M3FNUZ operator""_f8e4m3p8fnuz(unsigned long long int v) { + return Float8E4M3FNUZ(narrow(v), Float8E4M3FNUZ::FromBits()); +} + +inline Float8E4M3FNUZ operator""_f8e4m3fnuzp8(long double v) { + return Float8E4M3FNUZ(static_cast(v), true); +} + +#endif + +inline void Float8E4M3FNUZToFloat(const Float8E4M3FNUZ* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToFloat8E4M3FNUZ(const float* flt, Float8E4M3FNUZ* blf, size_t size, bool saturate) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) Float8E4M3FNUZ(*src, saturate); + } +} + +// Float8E5M2 +struct Float8E5M2 { + uint8_t val{0}; +#if defined(__HIP__) + ORT_HOST_DEVICE Float8E5M2() = default; +#else + Float8E5M2() = default; +#endif + + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float8E5M2(unsigned char bits, FromBitsT) : val(bits) {} + + inline explicit ORT_HOST_DEVICE Float8E5M2(float v, bool saturate = true) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + val = __nv_cvt_float_to_fp8(v, saturate ? 
__NV_SATFINITE : __NV_NOSAT, __NV_E5M2); +#else + uint32_t b; + std::memcpy(&b, &v, sizeof(b)); + + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + if (saturate) { + // the highest available value + val |= 0x7B; + } else { + // the infinity + val |= 0x7C; + } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; + } else { + uint32_t e = (b & 0x7F800000) >> 23; // exponent + uint32_t m = b & 0x007FFFFF; // mantissa + + if (e != 0) { + if (e < 110) { + } else if (e < 113) { + // denormalized number + auto d = 112 - e; + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (21 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 143) { // 127 + 15 + 1 + auto ex = e - 112; // 127 - 15 + val |= ex << 2; + val |= m >> 21; + if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { + if ((val & 0x7F) < 0x7B) { + // rounding + val += 1; + } else if (saturate) { + val |= 0x7B; + } else { + val |= 0x7C; + } + } + } else if (saturate) { + val |= 0x7B; + } else { + val |= 0x7C; + } + } + } +#endif + } + + inline ORT_HOST_DEVICE bool IsNaN() const { + // 7D, 7E, 7F are positive NaNs; FD, FE, FF are negative NaNs + return (val & 0b01111111) > 0b01111100; + } + + inline ORT_HOST_DEVICE bool IsInfinity() const { + // 7C and FC are infinity + return (val & 0b01111111) == 0b01111100; + } + + inline ORT_HOST_DEVICE float ToFloat() const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E5M2)); +#else + uint32_t res; + if (val >= 253) { + res = 0xffc00000; + } else if (val >= 125 && val <= 127) { + res = 0x7fc00000; + } else if (val == 252) { + res = 0xff800000; + } else if (val == 124) { + res = 0x7f800000; + } else { + uint32_t expo = (val & 0x7C) >> 2; + uint32_t mant = val & 0x03; + uint32_t sign = val & 0x80; + res = sign << 24; + if (expo == 0) { + if (mant > 0) { + expo = 0x7F - 15; + if ((mant & 0x2) == 0) { + mant &= 0x1; + mant <<= 1; + expo -= 1; + } + res |= (mant & 0x1) << 22; + res |= expo << 23; + } + } else { + res |= mant << 21; + expo -= 15; + expo += 0x7F; + res |= expo << 23; + } + } + + float float_res; + std::memcpy(&float_res, &res, sizeof(float)); + return float_res; +#endif + } + + inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + ORT_HOST_DEVICE Float8E5M2(const __nv_fp8_e5m2& value) { val = *reinterpret_cast(&value); } + explicit ORT_HOST_DEVICE operator __nv_fp8_e5m2() const { return *reinterpret_cast(&val); } +#endif +}; + +inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { return left.val == right.val; } +inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator<(const Float8E5M2& left, const Float8E5M2& right) { return left.val < right.val; } + +// User defined suffixes to make it easier to declare +// initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char +#if !defined(__CUDACC__) && !defined(__HIPCC__) + +inline Float8E5M2 operator""_f8e5m2fn(unsigned long long int v) { + return Float8E5M2(narrow(v), Float8E5M2::FromBits()); +} + +inline Float8E5M2 operator""_f8e5m2fnp8(long double v) { + return Float8E5M2(static_cast(v), true); +} + +#endif + +inline void 
Float8E5M2ToFloat(const Float8E5M2* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToFloat8E5M2(const float* flt, Float8E5M2* blf, size_t size, bool saturate) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) Float8E5M2(*src, saturate); + } +} + +// Float8E5M2FNUZ +struct Float8E5M2FNUZ { + uint8_t val{0}; +#if defined(__HIP__) + ORT_HOST_DEVICE Float8E5M2FNUZ() = default; +#else + Float8E5M2FNUZ() = default; +#endif + + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float8E5M2FNUZ(unsigned char bits, FromBitsT) : val(bits) {} + + inline explicit ORT_HOST_DEVICE Float8E5M2FNUZ(float v, bool saturate = true) { + // This type does not exist on CUDA. + uint32_t b; + std::memcpy(&b, &v, sizeof(b)); + + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + if (saturate) { + val |= 0x7F; + } else { + val = 0x80; + } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; + } else { + uint32_t e = (b & 0x7F800000) >> 23; // exponent + uint32_t m = b & 0x007FFFFF; // mantissa + + if (e < 109) { + // all near-zero numbers round to positive zero: + val = 0; + } else if (e < 112) { + // denormalized number + auto d = 111 - e; + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } else { + // round to positive zero: + val = 0; + } + auto mask = 1 << (21 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 143) { + // normalized number + auto ex = e - 111; + val |= ex << 2; + val |= m >> 21; + if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { + if ((val & 0x7F) < 0x7F) { + // rounding + val += 1; + } else if (!saturate) { + val = 0x80; + } + } + } else if ((e == 255) && (m == 0)) { + val = 0x80; + } else if (saturate) { + val |= 0x7F; + } else { + val = 0x80; + } + } + } + + inline ORT_HOST_DEVICE bool IsNaN() const { + return val == 0b10000000; + } + + inline ORT_HOST_DEVICE float ToFloat() const { + // This type does not exist on CUDA. 
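// Same reading aid as for the E4M3FNUZ decoder above, with E5M2FNUZ parameters:
// 5 exponent bits (bias 16) and 2 mantissa bits, no infinities, and 0x80 as the
// only NaN encoding; e.g. 0x40 (exponent field 16, mantissa 0) decodes to 1.0f.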
+ uint32_t res; + if (val == 0x80) { + res = 0xffc00000; + } else { + uint32_t expo = (val & 0x7C) >> 2; + uint32_t mant = val & 0x03; + uint32_t sign = val & 0x80; + res = sign << 24; + if (expo == 0) { + if (mant > 0) { + expo = 0x7F - 16; + if ((mant & 0x2) == 0) { + mant &= 0x1; + mant <<= 1; + expo -= 1; + } + res |= (mant & 0x1) << 22; + res |= expo << 23; + } + } else { + res |= mant << 21; + expo -= 16; + expo += 0x7F; + res |= expo << 23; + } + } + + float float_res; + std::memcpy(&float_res, &res, sizeof(float)); + return float_res; + } + + inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } +}; + +inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val == right.val; } +inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val < right.val; } + +// User defined suffixes to make it easier to declare +// initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char +#if !defined(__CUDACC__) && !defined(__HIPCC__) + +inline Float8E5M2FNUZ operator""_f8e5m2fnuz(unsigned long long int v) { + return Float8E5M2FNUZ(narrow(v), Float8E5M2FNUZ::FromBits()); +} + +inline Float8E5M2FNUZ operator""_f8e5m2fnuzp8(long double v) { + return Float8E5M2FNUZ(static_cast(v), true); +} + +#endif + +inline void Float8E5M2FNUZToFloat(const Float8E5M2FNUZ* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToFloat8E5M2FNUZ(const float* flt, Float8E5M2FNUZ* blf, size_t size, bool saturate) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) Float8E5M2FNUZ(*src, saturate); + } +} + +} // namespace onnxruntime + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E4M3FN lowest() { + return onnxruntime::Float8E4M3FN(0xFE, onnxruntime::Float8E4M3FN::FromBits()); // -448 + } + + static constexpr onnxruntime::Float8E4M3FN max() { + return onnxruntime::Float8E4M3FN(0x7E, onnxruntime::Float8E4M3FN::FromBits()); // 448 + } + + static constexpr onnxruntime::Float8E4M3FN min() { + return onnxruntime::Float8E4M3FN(0x08, onnxruntime::Float8E4M3FN::FromBits()); // 2^-6 = 0.015625 + } + + static constexpr onnxruntime::Float8E4M3FN denorm_min() { + return onnxruntime::Float8E4M3FN(0x01, onnxruntime::Float8E4M3FN::FromBits()); // 2^-9 = 0.001953125 + } + + static constexpr onnxruntime::Float8E4M3FN epsilon() { + return onnxruntime::Float8E4M3FN(0x20, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FN round_error() { + return onnxruntime::Float8E4M3FN(0x30, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FN infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E4M3FN quiet_NaN() { + return onnxruntime::Float8E4M3FN(0x7F, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static 
constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -5; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E5M2 lowest() { + return onnxruntime::Float8E5M2(0xFB, onnxruntime::Float8E5M2::FromBits()); // -57344.0 + } + + static constexpr onnxruntime::Float8E5M2 max() { + return onnxruntime::Float8E5M2(0x7B, onnxruntime::Float8E5M2::FromBits()); // 57344.0 + } + + static constexpr onnxruntime::Float8E5M2 min() { + return onnxruntime::Float8E5M2(0x4, onnxruntime::Float8E5M2::FromBits()); // 2^-14 = 0.00006103515 + } + + static constexpr onnxruntime::Float8E5M2 denorm_min() { + return onnxruntime::Float8E5M2(0x01, onnxruntime::Float8E5M2::FromBits()); // 2^-16 = 0.00001525878 + } + + static constexpr onnxruntime::Float8E5M2 epsilon() { + return onnxruntime::Float8E5M2(0x34, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 round_error() { + return onnxruntime::Float8E5M2(0x38, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 infinity() { + return onnxruntime::Float8E5M2(0x7C, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 quiet_NaN() { + return onnxruntime::Float8E5M2(0x7F, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E4M3FNUZ lowest() { + return onnxruntime::Float8E4M3FNUZ(0xFF, onnxruntime::Float8E4M3FNUZ::FromBits()); // -240.0 + } + + static constexpr onnxruntime::Float8E4M3FNUZ max() { + return onnxruntime::Float8E4M3FNUZ(0x7F, onnxruntime::Float8E4M3FNUZ::FromBits()); // 240.0 + } + + static constexpr onnxruntime::Float8E4M3FNUZ min() { + return onnxruntime::Float8E4M3FNUZ(0x08, onnxruntime::Float8E4M3FNUZ::FromBits()); // 2^-7 = 0.0078125 + } + + static constexpr onnxruntime::Float8E4M3FNUZ denorm_min() { + return onnxruntime::Float8E4M3FNUZ(0x01, onnxruntime::Float8E4M3FNUZ::FromBits()); // 2^-10 = 
0.0009765625 + } + + static constexpr onnxruntime::Float8E4M3FNUZ epsilon() { + return onnxruntime::Float8E4M3FNUZ(0x28, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FNUZ round_error() { + return onnxruntime::Float8E4M3FNUZ(0x38, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FNUZ infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E4M3FNUZ quiet_NaN() { + return onnxruntime::Float8E4M3FNUZ(0x80, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -6; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E5M2FNUZ lowest() { + return onnxruntime::Float8E5M2FNUZ(0xFF, onnxruntime::Float8E5M2FNUZ::FromBits()); // -57344.0 + } + + static constexpr onnxruntime::Float8E5M2FNUZ max() { + return onnxruntime::Float8E5M2FNUZ(0x7F, onnxruntime::Float8E5M2FNUZ::FromBits()); // 57344.0 + } + + static constexpr onnxruntime::Float8E5M2FNUZ min() { + return onnxruntime::Float8E5M2FNUZ(0x04, onnxruntime::Float8E5M2FNUZ::FromBits()); // 2^-15 = 0.00003051757 + } + + static constexpr onnxruntime::Float8E5M2FNUZ denorm_min() { + return onnxruntime::Float8E5M2FNUZ(0x01, onnxruntime::Float8E5M2FNUZ::FromBits()); // 2^-17 = 0.00000762939 + } + + static constexpr onnxruntime::Float8E5M2FNUZ epsilon() { + return onnxruntime::Float8E5M2FNUZ(0x34, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2FNUZ round_error() { + return onnxruntime::Float8E5M2FNUZ(0x38, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2FNUZ infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E5M2FNUZ quiet_NaN() { + return onnxruntime::Float8E5M2FNUZ(0x80, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static 
constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -14; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +} // namespace std + +#endif // DISABLE_FLOAT8_TYPES diff --git a/src/ort_include/core/common/logging/logging.h b/src/ort_include/core/common/logging/logging.h index 571262a..508c22d 100644 --- a/src/ort_include/core/common/logging/logging.h +++ b/src/ort_include/core/common/logging/logging.h @@ -17,7 +17,6 @@ #include "core/common/logging/macros.h" #include "core/common/logging/severity.h" #include "core/common/logging/sink_types.h" -#include "core/platform/ort_mutex.h" /* @@ -51,8 +50,9 @@ */ -namespace onnxruntime { +struct OrtLogger; // opaque API type. is always an instance of Logger +namespace onnxruntime { namespace logging { using Timestamp = std::chrono::time_point; @@ -85,7 +85,7 @@ using Timestamp = std::chrono::time_point; #endif #endif // __APPLE__ -#if _WIN32 || ORT_USE_CXX20_STD_CHRONO +#if ORT_USE_CXX20_STD_CHRONO namespace timestamp_ns = std::chrono; #else namespace timestamp_ns = ::date; @@ -258,7 +258,7 @@ class LoggingManager final { std::unique_ptr sink_; #ifdef _WIN32 - mutable OrtMutex sink_mutex_; + mutable std::mutex sink_mutex_; #endif Severity default_min_severity_; const bool default_filter_user_data_; @@ -352,6 +352,10 @@ class Logger { logging_manager_->SendProfileEvent(eventRecord); } + // convert to API type for custom ops and plugin EPs + OrtLogger* ToExternal() { return reinterpret_cast(this); } + const OrtLogger* ToExternal() const { return reinterpret_cast(this); } + private: const LoggingManager* logging_manager_; const std::string id_; diff --git a/src/ort_include/core/common/parse_string.h b/src/ort_include/core/common/parse_string.h index 941e3f3..5f88d49 100644 --- a/src/ort_include/core/common/parse_string.h +++ b/src/ort_include/core/common/parse_string.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -12,18 +13,62 @@ namespace onnxruntime { +namespace detail { + +// Whether we will use std::from_chars() to parse to `T`. +#if defined(_LIBCPP_VERSION) +// Note: Currently (e.g., in LLVM 19), libc++'s std::from_chars() doesn't support floating point types yet. +template +constexpr bool ParseWithFromChars = !std::is_same_v && std::is_integral_v; +#else +template +constexpr bool ParseWithFromChars = !std::is_same_v && (std::is_integral_v || std::is_floating_point_v); +#endif + +} // namespace detail + /** * Tries to parse a value from an entire string. + * If successful, sets `value` and returns true. Otherwise, does not modify `value` and returns false. */ template -bool TryParseStringWithClassicLocale(std::string_view str, T& value) { - if constexpr (std::is_integral::value && std::is_unsigned::value) { - // if T is unsigned integral type, reject negative values which will wrap - if (!str.empty() && str[0] == '-') { - return false; +std::enable_if_t, bool> +TryParseStringWithClassicLocale(std::string_view str, T& value) { + T parsed_value{}; + + std::from_chars_result conversion_result{}; + if constexpr (std::is_integral_v && std::is_unsigned_v) { + // For unsigned integral types, also handle hex values, i.e., those beginning with "0x". + // std::from_chars() does not accept the "0x" prefix. 
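For reference, the std::from_chars-based strategy in the parse_string.h hunk here can be exercised in isolation. The sketch below (the helper name is illustrative, not ONNX Runtime API) shows the manual "0x" prefix handling and the requirement that the whole input be consumed before the parse counts as successful.

    // Usage: ParseUnsigned("0x1A", v) -> true, v == 26; ParseUnsigned("12ab", v) -> false.
    #include <charconv>
    #include <cstdint>
    #include <string_view>
    #include <system_error>

    bool ParseUnsigned(std::string_view str, uint64_t& value) {
      int base = 10;
      if (str.size() >= 2 && str[0] == '0' && (str[1] == 'x' || str[1] == 'X')) {
        str.remove_prefix(2);  // std::from_chars() does not accept the "0x" prefix itself
        base = 16;
      }
      uint64_t parsed{};
      const auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), parsed, base);
      // Reject conversion errors and trailing characters; only a fully consumed input succeeds.
      if (ec != std::errc{} || ptr != str.data() + str.size()) {
        return false;
      }
      value = parsed;
      return true;
    }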
+ const bool has_hex_prefix = str.size() >= 2 && + str[0] == '0' && + (str[1] == 'x' || str[1] == 'X'); + + if (has_hex_prefix) { + str = str.substr(2); } + + const int base = has_hex_prefix ? 16 : 10; + conversion_result = std::from_chars(str.data(), str.data() + str.size(), parsed_value, base); + } else { + conversion_result = std::from_chars(str.data(), str.data() + str.size(), parsed_value); } + if (conversion_result.ec != std::errc{}) { + return false; + } + + if (conversion_result.ptr != str.data() + str.size()) { + return false; + } + + value = parsed_value; + return true; +} + +template +std::enable_if_t, bool> +TryParseStringWithClassicLocale(std::string_view str, T& value) { // don't allow leading whitespace if (!str.empty() && std::isspace(str[0], std::locale::classic())) { return false; diff --git a/src/ort_include/core/common/path_string.h b/src/ort_include/core/common/path_string.h index 6cfb327..4ca326d 100644 --- a/src/ort_include/core/common/path_string.h +++ b/src/ort_include/core/common/path_string.h @@ -40,6 +40,12 @@ inline PathString ToPathString(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::wstring!"); +inline PathString ToPathString(std::string_view s) { + return ToWideString(s); +} +inline PathString ToPathString(const char* s) { + return ToWideString(s); +} inline PathString ToPathString(const std::string& s) { return ToWideString(s); } @@ -56,6 +62,14 @@ inline std::string PathToUTF8String(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::string!"); +inline PathString ToPathString(const char* s) { + return s; +} + +inline PathString ToPathString(std::string_view s) { + return PathString{s}; +} + inline PathChar ToLowerPathChar(PathChar c) { return std::tolower(c); } diff --git a/src/ort_include/core/common/profiler_common.h b/src/ort_include/core/common/profiler_common.h index 0074d5e..ab97325 100644 --- a/src/ort_include/core/common/profiler_common.h +++ b/src/ort_include/core/common/profiler_common.h @@ -81,8 +81,8 @@ class EpProfiler { virtual ~EpProfiler() = default; virtual bool StartProfiling(TimePoint profiling_start_time) = 0; // called when profiling starts virtual void EndProfiling(TimePoint start_time, Events& events) = 0; // called when profiling ends, save all captures numbers to "events" - virtual void Start(uint64_t){}; // called before op start, accept an id as argument to identify the op - virtual void Stop(uint64_t){}; // called after op stop, accept an id as argument to identify the op + virtual void Start(uint64_t) {} // called before op start, accept an id as argument to identify the op + virtual void Stop(uint64_t) {} // called after op stop, accept an id as argument to identify the op }; // Demangle C++ symbols diff --git a/src/ort_include/core/common/spin_pause.h b/src/ort_include/core/common/spin_pause.h index 49b71e5..4d987f1 100644 --- a/src/ort_include/core/common/spin_pause.h +++ b/src/ort_include/core/common/spin_pause.h @@ -3,26 +3,11 @@ #pragma once -#if defined(_M_AMD64) -#include -#endif - -#if defined(__x86_64__) -#include -#endif - namespace onnxruntime { - namespace concurrency { // Intrinsic to use in spin-loops - -inline void SpinPause() { -#if defined(_M_AMD64) || defined(__x86_64__) - _mm_pause(); -#endif -} +void SpinPause(); } // namespace concurrency - } // namespace onnxruntime diff --git a/src/ort_include/core/common/status.h b/src/ort_include/core/common/status.h index 8f171da..8cf6420 100644 --- a/src/ort_include/core/common/status.h +++ 
b/src/ort_include/core/common/status.h @@ -43,7 +43,10 @@ enum StatusCode { MODEL_LOADED = 8, NOT_IMPLEMENTED = 9, INVALID_GRAPH = 10, - EP_FAIL = 11 + EP_FAIL = 11, + MODEL_LOAD_CANCELED = 12, + MODEL_REQUIRES_COMPILATION = 13, + NOT_FOUND = 14, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -72,6 +75,12 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "INVALID_GRAPH"; case StatusCode::EP_FAIL: return "EP_FAIL"; + case StatusCode::MODEL_LOAD_CANCELED: + return "MODEL_LOAD_CANCELED"; + case StatusCode::MODEL_REQUIRES_COMPILATION: + return "MODEL_REQUIRES_COMPILATION"; + case StatusCode::NOT_FOUND: + return "NOT_FOUND"; default: return "GENERAL ERROR"; } @@ -104,6 +113,12 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); case StatusCode::EP_FAIL: return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::MODEL_LOAD_CANCELED: + return HRESULT_FROM_WIN32(ERROR_CANCELLED); + case StatusCode::MODEL_REQUIRES_COMPILATION: + return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED); + case StatusCode::NOT_FOUND: + return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); default: return E_FAIL; } diff --git a/src/ort_include/core/common/string_helper.h b/src/ort_include/core/common/string_helper.h index 1304303..c0b331c 100644 --- a/src/ort_include/core/common/string_helper.h +++ b/src/ort_include/core/common/string_helper.h @@ -7,5 +7,9 @@ // forward declaration struct OrtAllocator; namespace onnxruntime { -char* StrDup(const std::string& str, OrtAllocator* allocator); +char* StrDup(std::string_view str, OrtAllocator* allocator); +inline char* StrDup(const std::string& str, OrtAllocator* allocator) { + return StrDup(std::string_view{str}, allocator); +} +wchar_t* StrDup(std::wstring_view str, OrtAllocator* allocator); } // namespace onnxruntime diff --git a/src/ort_include/core/framework/callback.h b/src/ort_include/core/framework/callback.h deleted file mode 100644 index e69de29..0000000 diff --git a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h index 26237b3..c313944 100644 --- a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h +++ b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h @@ -50,7 +50,6 @@ #include "core/common/denormal.h" #include "core/common/inlined_containers_fwd.h" #include "core/common/spin_pause.h" -#include "core/platform/ort_mutex.h" #include "core/platform/ort_spin_lock.h" // ORT thread pool overview @@ -200,100 +199,6 @@ struct PaddingToAvoidFalseSharing { char padding[ORT_FALSE_SHARING_BYTES]; }; -/* Usage: -1. In executor, call Start() before profiling and Stop() to get profiled numbers; -2. Inside thread pool, call LogStart() before interested section and LogEnd... after to log elapsed time; -3. To extend, just add more events in enum Event before "All", and update GetEventName(...) accordingly; -4. Note LogStart must pair with either LogEnd or LogEndAndStart, otherwise ORT_ENFORCE will fail; -5. ThreadPoolProfiler is thread-safe. 
-*/ -#ifdef ORT_MINIMAL_BUILD -class ThreadPoolProfiler { - public: - enum ThreadPoolEvent { - DISTRIBUTION = 0, - DISTRIBUTION_ENQUEUE, - RUN, - WAIT, - WAIT_REVOKE, - MAX_EVENT - }; - ThreadPoolProfiler(int, const CHAR_TYPE*) {}; - ~ThreadPoolProfiler() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); - void Start() {}; - std::string Stop() { return "not available for minimal build"; } - void LogStart() {}; - void LogEnd(ThreadPoolEvent){}; - void LogEndAndStart(ThreadPoolEvent){}; - void LogStartAndCoreAndBlock(std::ptrdiff_t){}; - void LogCoreAndBlock(std::ptrdiff_t){}; - void LogThreadId(int) {}; - void LogRun(int) {}; - std::string DumpChildThreadStat() { return {}; } -}; -#else -class ThreadPoolProfiler { - public: - enum ThreadPoolEvent { - DISTRIBUTION = 0, - DISTRIBUTION_ENQUEUE, - RUN, - WAIT, - WAIT_REVOKE, - MAX_EVENT - }; - ThreadPoolProfiler(int num_threads, const CHAR_TYPE* threal_pool_name); - ~ThreadPoolProfiler(); - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); - using Clock = std::chrono::high_resolution_clock; - void Start(); // called by executor to start profiling - std::string Stop(); // called by executor to stop profiling and return collected numbers - void LogStart(); // called in main thread to record the starting time point - void LogEnd(ThreadPoolEvent); // called in main thread to calculate and save the time elapsed from last start point - void LogEndAndStart(ThreadPoolEvent); - void LogStartAndCoreAndBlock(std::ptrdiff_t block_size); - void LogCoreAndBlock(std::ptrdiff_t block_size); // called in main thread to log core and block size for task breakdown - void LogThreadId(int thread_idx); // called in child thread to log its id - void LogRun(int thread_idx); // called in child thread to log num of run - std::string DumpChildThreadStat(); // return all child statitics collected so far - - private: - static const char* GetEventName(ThreadPoolEvent); - struct MainThreadStat { - uint64_t events_[MAX_EVENT] = {}; - int32_t core_ = -1; - std::vector blocks_; // block size determined by cost model - std::vector points_; - void LogCore(); - void LogBlockSize(std::ptrdiff_t block_size); - void LogStart(); - void LogEnd(ThreadPoolEvent); - void LogEndAndStart(ThreadPoolEvent); - std::string Reset(); - }; - bool enabled_ = false; - MainThreadStat& GetMainThreadStat(); // return thread local stat - int num_threads_; -#ifdef _MSC_VER -#pragma warning(push) - // C4324: structure was padded due to alignment specifier -#pragma warning(disable : 4324) -#endif // _MSC_VER - struct ORT_ALIGN_TO_AVOID_FALSE_SHARING ChildThreadStat { - std::thread::id thread_id_; - uint64_t num_run_ = 0; - onnxruntime::TimePoint last_logged_point_ = Clock::now(); - int32_t core_ = -1; // core that the child thread is running on - }; -#ifdef _MSC_VER -#pragma warning(pop) -#endif // _MSC_VER - std::vector child_thread_stats_; - std::string thread_pool_name_; -}; -#endif - // Extended Eigen thread pool interface, avoiding the need to modify // the ThreadPoolInterface.h header from the external Eigen // repository. @@ -336,8 +241,6 @@ class ExtendedThreadPoolInterface : public Eigen::ThreadPoolInterface { // two loops execute in series in a parallel section. 
] virtual void RunInParallel(std::function fn, unsigned n, std::ptrdiff_t block_size) = 0; - virtual void StartProfiling() = 0; - virtual std::string StopProfiling() = 0; }; class ThreadPoolParallelSection { @@ -459,7 +362,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back = back_.load(std::memory_order_relaxed); Elem& e = array_[(back - 1) & kMask]; @@ -483,7 +386,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back = back_.load(std::memory_order_relaxed); w_idx = (back - 1) & kMask; @@ -508,7 +411,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back; Elem* e; @@ -554,7 +457,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif Elem& e = array_[w_idx]; ElemState s = e.state.load(std::memory_order_relaxed); @@ -630,7 +533,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE OrtSpinLock spin_lock_; #else - OrtMutex mutex_; + std::mutex mutex_; #endif // Low log(kSize) + 1 bits in front_ and back_ contain rolling index of @@ -706,7 +609,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter return 0; } - ThreadPoolProfiler profiler_; void SignalAllAndWait() { done_ = true; @@ -721,13 +623,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } public: - void StartProfiling() override { - profiler_.Start(); - } - std::string StopProfiling() override { - return profiler_.Stop(); - } struct Tag { constexpr Tag() : v_(0) { @@ -739,7 +635,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // Allocate a new tag to use to identify work items from a given // thread in a parallel section. Ideally, threads will have // unique tags, but re-use is not incorrect if the counter wraps - // (for intsance, if a long-running workload is calling into ORT + // (for instance, if a long-running workload is calling into ORT // from a fresh thread for each request). We must not re-use the // default tag 0 which is used to identify work items added via // Schedule as opposed to requests for help in parallel sections. @@ -768,7 +664,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ThreadPoolTempl(const CHAR_TYPE* name, int num_threads, bool allow_spinning, Environment& env, const ThreadOptions& thread_options) - : profiler_(num_threads, name), + : env_(env), num_threads_(num_threads), allow_spinning_(allow_spinning), @@ -912,11 +808,10 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } } - // Now we know that dispatch is finshed, we synchronize with the + // Now we know that dispatch is finished, we synchronize with the // tasks that were created (if any) for the parallel section. We // revoke tasks still in queues, and then wait for any that are // still running. 
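A note on the RunQueue hunks above: OrtMutex and OrtCondVar are replaced with std::mutex and std::condition_variable, so the guards are ordinary std::lock_guard objects over std::mutex. The snippet below is a minimal sketch of that locking pattern (the class is illustrative, not part of the thread pool); with C++17 class template argument deduction the guard can be spelled without an explicit template argument.

    #include <mutex>

    // Illustrative only: the same guard pattern the thread pool code now uses.
    class Counter {
     public:
      void Increment() {
        std::lock_guard lock(mutex_);  // deduced as std::lock_guard<std::mutex> (C++17 CTAD)
        ++value_;
      }
      int Value() const {
        std::lock_guard lock(mutex_);
        return value_;
      }

     private:
      mutable std::mutex mutex_;
      int value_ = 0;
    };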
- profiler_.LogStart(); unsigned tasks_started = static_cast(ps.tasks.size()); while (!ps.tasks.empty()) { const auto& item = ps.tasks.back(); @@ -926,7 +821,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } ps.tasks.pop_back(); } - profiler_.LogEnd(ThreadPoolProfiler::WAIT_REVOKE); // Wait for the dispatch task's own work... if (ps.dispatch_q_idx > -1) { @@ -1004,7 +898,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // * First, suppose that a single job is running a series of loops. // Its main thread enters a parallel loop. Initially, let's assume // its preferred worker array is [_,0,1,2], writing "_" for the - // unusued element for the par_idx=0 work that the main thread will + // unused element for the par_idx=0 work that the main thread will // run. // // The main thread schedules the dispatcher task onto worker 0. @@ -1205,7 +1099,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ps.work_done.store(true, std::memory_order_release); }; - profiler_.LogStart(); ps.dispatch_q_idx = preferred_workers[current_dop] % num_threads_; WorkerData& dispatch_td = worker_data_[ps.dispatch_q_idx]; Queue& dispatch_que = dispatch_td.queue; @@ -1223,7 +1116,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } else { ps.dispatch_q_idx = -1; // failed to enqueue dispatch_task } - profiler_.LogEnd(ThreadPoolProfiler::DISTRIBUTION_ENQUEUE); } else { // Synchronous dispatch ScheduleOnPreferredWorkers(pt, ps, preferred_workers, current_dop, new_dop, std::move(worker_fn)); @@ -1241,7 +1133,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter unsigned n, std::ptrdiff_t block_size) override { ORT_ENFORCE(n <= num_threads_ + 1, "More work items than threads"); - profiler_.LogStartAndCoreAndBlock(block_size); PerThread* pt = GetPerThread(); assert(pt->leading_par_section && "RunInParallel, but not in parallel section"); assert((n > 1) && "Trivial parallel section; should be avoided by caller"); @@ -1271,18 +1162,15 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter }; RunInParallelInternal(*pt, ps, n, false, std::move(worker_fn)); assert(ps.dispatch_q_idx == -1); - profiler_.LogEndAndStart(ThreadPoolProfiler::DISTRIBUTION); // Run work in the main thread loop.fn(0); - profiler_.LogEndAndStart(ThreadPoolProfiler::RUN); // Wait for workers to exit the loop ps.current_loop = 0; while (ps.workers_in_loop) { onnxruntime::concurrency::SpinPause(); } - profiler_.LogEnd(ThreadPoolProfiler::WAIT); } // Run a single parallel loop _without_ a parallel section. This is a @@ -1299,16 +1187,12 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // 1. 
run fn(...); void RunInParallel(std::function fn, unsigned n, std::ptrdiff_t block_size) override { ORT_ENFORCE(n <= num_threads_ + 1, "More work items than threads"); - profiler_.LogStartAndCoreAndBlock(block_size); PerThread* pt = GetPerThread(); ThreadPoolParallelSection ps; StartParallelSectionInternal(*pt, ps); RunInParallelInternal(*pt, ps, n, true, fn); // select dispatcher and do job distribution; - profiler_.LogEndAndStart(ThreadPoolProfiler::DISTRIBUTION); fn(0); // run fn(0) - profiler_.LogEndAndStart(ThreadPoolProfiler::RUN); EndParallelSectionInternal(*pt, ps); // wait for all - profiler_.LogEnd(ThreadPoolProfiler::WAIT); } int NumThreads() const final { @@ -1439,7 +1323,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ThreadStatus seen = GetStatus(); if (seen == ThreadStatus::Blocking || seen == ThreadStatus::Blocked) { - std::unique_lock lk(mutex); + std::unique_lock lk(mutex); // Blocking state exists only transiently during the SetBlock() method // while holding the lock. We may observe it at the start of this // function, but after acquiring the lock then the target thread @@ -1467,11 +1351,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter status = ThreadStatus::Spinning; } - void SetBlocked(std::function should_block, + bool SetBlocked(std::function should_block, std::function post_block) { - std::unique_lock lk(mutex); - assert(GetStatus() == ThreadStatus::Spinning); - status.store(ThreadStatus::Blocking, std::memory_order_relaxed); + std::unique_lock lk(mutex); + auto old_status = status.exchange(ThreadStatus::Blocking, std::memory_order_seq_cst); + if (old_status != ThreadStatus::Spinning) { + // Encountered a logical error + return false; + } if (should_block()) { status.store(ThreadStatus::Blocked, std::memory_order_relaxed); do { @@ -1480,12 +1367,13 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter post_block(); } status.store(ThreadStatus::Spinning, std::memory_order_relaxed); + return true; } private: std::atomic status{ThreadStatus::Spinning}; - OrtMutex mutex; - OrtCondVar cv; + std::mutex mutex; + std::condition_variable cv; }; Environment& env_; @@ -1536,7 +1424,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter const int steal_count = spin_count / 100; SetDenormalAsZero(set_denormal_as_zero_); - profiler_.LogThreadId(thread_id); while (!should_exit) { Task t = q.PopFront(); @@ -1558,62 +1445,66 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // Attempt to block if (!t) { - td.SetBlocked( // Pre-block test - [&]() -> bool { - bool should_block = true; - // Check whether work was pushed to us while attempting to block. We make - // this test while holding the per-thread status lock, and after setting - // our status to ThreadStatus::Blocking. - // - // This synchronizes with ThreadPool::Schedule which pushes work to the queue - // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake): - // - // Main thread: Worker: - // #1 Push work #A Set status blocking - // #2 Read worker status #B Check queue - // #3 Wake if blocking/blocked - // - // If #A is before #2 then main sees worker blocked and wakes - // - // If #A if after #2 then #B will see #1, and we abandon blocking - assert(!t); - t = q.PopFront(); - if (t) { - should_block = false; - } - - // No work pushed to us, continue attempting to block. The remaining - // test is to synchronize with termination requests. 
If we are - // shutting down and all worker threads blocked without work, that's - // we are done. - if (should_block) { - blocked_++; - if (done_ && blocked_ == num_threads_) { - should_block = false; - // Almost done, but need to re-check queues. - // Consider that all queues are empty and all worker threads are preempted - // right after incrementing blocked_ above. Now a free-standing thread - // submits work and calls destructor (which sets done_). If we don't - // re-check queues, we will exit leaving the work unexecuted. - if (NonEmptyQueueIndex() != -1) { - // Note: we must not pop from queues before we decrement blocked_, - // otherwise the following scenario is possible. Consider that instead - // of checking for emptiness we popped the only element from queues. - // Now other worker threads can start exiting, which is bad if the - // work item submits other work. So we just check emptiness here, - // which ensures that all worker threads exit at the same time. - blocked_--; - } else { - should_exit = true; + if (!td.SetBlocked( // Pre-block test + [&]() -> bool { + bool should_block = true; + // Check whether work was pushed to us while attempting to block. We make + // this test while holding the per-thread status lock, and after setting + // our status to ThreadStatus::Blocking. + // + // This synchronizes with ThreadPool::Schedule which pushes work to the queue + // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake): + // + // Main thread: Worker: + // #1 Push work #A Set status blocking + // #2 Read worker status #B Check queue + // #3 Wake if blocking/blocked + // + // If #A is before #2 then main sees worker blocked and wakes + // + // If #A if after #2 then #B will see #1, and we abandon blocking + assert(!t); + t = q.PopFront(); + if (t) { + should_block = false; + } + + // No work pushed to us, continue attempting to block. The remaining + // test is to synchronize with termination requests. If we are + // shutting down and all worker threads blocked without work, that's + // we are done. + if (should_block) { + blocked_++; + if (done_ && blocked_ == num_threads_) { + should_block = false; + // Almost done, but need to re-check queues. + // Consider that all queues are empty and all worker threads are preempted + // right after incrementing blocked_ above. Now a free-standing thread + // submits work and calls destructor (which sets done_). If we don't + // re-check queues, we will exit leaving the work unexecuted. + if (NonEmptyQueueIndex() != -1) { + // Note: we must not pop from queues before we decrement blocked_, + // otherwise the following scenario is possible. Consider that instead + // of checking for emptiness we popped the only element from queues. + // Now other worker threads can start exiting, which is bad if the + // work item submits other work. So we just check emptiness here, + // which ensures that all worker threads exit at the same time. + blocked_--; + } else { + should_exit = true; + } + } } - } - } - return should_block; - }, - // Post-block update (executed only if we blocked) - [&]() { - blocked_--; - }); + return should_block; + }, + // Post-block update (executed only if we blocked) + [&]() { + blocked_--; + })) { + // Encountered a fatal logic error in SetBlocked + should_exit = true; + break; + } // Thread just unblocked. 
Unless we picked up work while // blocking, or are exiting, then either work was pushed to // us, or it was pushed to an overloaded queue @@ -1625,7 +1516,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter if (t) { td.SetActive(); t(); - profiler_.LogRun(thread_id); td.SetSpinning(); } } diff --git a/src/ort_include/core/platform/env.h b/src/ort_include/core/platform/env.h index 47e74ea..c73c64f 100644 --- a/src/ort_include/core/platform/env.h +++ b/src/ort_include/core/platform/env.h @@ -26,7 +26,6 @@ limitations under the License. #include "core/common/common.h" #include "core/common/path_string.h" -#include "core/framework/callback.h" #include "core/session/onnxruntime_c_api.h" #ifndef _WIN32 @@ -37,7 +36,9 @@ namespace Eigen { class ThreadPoolInterface; } namespace onnxruntime { - +namespace concurrency { + inline void SpinPause(){} +} #ifdef _WIN32 using PIDType = unsigned long; using FileOffsetType = int64_t; @@ -78,7 +79,6 @@ struct ThreadOptions { // Set or unset denormal as zero. bool set_denormal_as_zero = false; - }; std::ostream& operator<<(std::ostream& os, const LogicalProcessors&); @@ -137,73 +137,12 @@ class Env { virtual int GetL2CacheSize() const = 0; - /// Sleeps/delays the thread for the prescribed number of micro-seconds. - /// On Windows, it's the min time to sleep, not the actual one. - virtual void SleepForMicroseconds(int64_t micros) const = 0; - - /** - * Gets the length of the specified file. - */ - virtual common::Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const = 0; - virtual common::Status GetFileLength(int fd, /*out*/ size_t& file_size) const = 0; - - /** - * Copies the content of the file into the provided buffer. - * @param file_path The path to the file. - * @param offset The file offset from which to start reading. - * @param length The length in bytes to read. - * @param buffer The buffer in which to write. - */ - virtual common::Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, - gsl::span buffer) const = 0; - - - /** Gets the canonical form of a file path (symlinks resolved). */ - virtual common::Status GetCanonicalPath( - const PathString& path, - PathString& canonical_path) const = 0; // This functions is always successful. It can't fail. virtual PIDType GetSelfPid() const = 0; - // \brief Load a dynamic library. - // - // Pass "library_filename" to a platform-specific mechanism for dynamically - // loading a library. The rules for determining the exact location of the - // library are platform-specific and are not documented here. - // - // global_symbols only has an effect on unix, where a value of true means to load with RTLD_GLOBAL vs RTLD_LOCAL - // - // On success, returns a handle to the library in "*handle" and returns - // OK from the function. - // Otherwise returns nullptr in "*handle" and an error status from the - // function. - virtual common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const = 0; - - virtual common::Status UnloadDynamicLibrary(void* handle) const = 0; - - // \brief Gets the file path of the onnx runtime code - // - // Used to help load other shared libraries that live in the same folder as the core code, for example - // The DNNL provider shared library. Without this path, the module won't be found on windows in all cases. - virtual PathString GetRuntimePath() const { return PathString(); } - - // \brief Get a pointer to a symbol from a dynamic library. 
- // - // "handle" should be a pointer returned from a previous call to LoadDynamicLibrary. - // On success, store a pointer to the located symbol in "*symbol" and return - // OK from the function. Otherwise, returns nullptr in "*symbol" and an error - // status from the function. - virtual common::Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const = 0; - - // \brief build the name of dynamic library. - // - // "name" should be name of the library. - // "version" should be the version of the library or NULL - // returns the name that LoadDynamicLibrary() can use - virtual std::string FormatLibraryFileName(const std::string& name, const std::string& version) const = 0; // \brief returns a value for the queried variable name (var_name) // diff --git a/src/ort_include/core/platform/ort_mutex.h b/src/ort_include/core/platform/ort_mutex.h deleted file mode 100644 index 5028b03..0000000 --- a/src/ort_include/core/platform/ort_mutex.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include -#include - -namespace onnxruntime{ - using OrtMutex = std::mutex; - using OrtCondVar = std::condition_variable; -} \ No newline at end of file diff --git a/src/ort_include/core/platform/windows/TraceLoggingConfig.h b/src/ort_include/core/platform/windows/TraceLoggingConfig.h deleted file mode 100644 index 2987167..0000000 --- a/src/ort_include/core/platform/windows/TraceLoggingConfig.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -/* ++ -Module Name: - TraceLoggingConfig.h -Abstract: - Macro definitions used by this project's TraceLogging ETW providers: - - Configuration macros that select the ETW Provider Groups to be used by - this project. - - Constants for tags that are commonly used in Microsoft's - TraceLogging-based ETW. - Different versions of this file use different definitions for the - TraceLoggingOption configuration macros. The definitions in this file are - empty. As a result, providers using this configuration file will not join - any ETW Provider Groups and will not be given any special treatment by - group-sensitive ETW listeners. -Environment: - User mode or kernel mode. ---*/ - -#pragma once - -// Configuration macro for use in TRACELOGGING_DEFINE_PROVIDER. The definition -// in this file configures the provider as a normal (non-telemetry) provider. -#ifndef TraceLoggingOptionMicrosoftTelemetry -#define TraceLoggingOptionMicrosoftTelemetry() \ - TraceLoggingOptionGroup(0000000000, 00000, 00000, 0000, 0000, 0000, 0000, 0000, 000, 0000, 0000) -// Empty definition for TraceLoggingOptionMicrosoftTelemetry -#endif - -// Configuration macro for use in TRACELOGGING_DEFINE_PROVIDER. The definition -// in this file configures the provider as a normal (non-telemetry) provider. -#define TraceLoggingOptionWindowsCoreTelemetry() \ - // Empty definition for TraceLoggingOptionWindowsCoreTelemetry - -// Event privacy tags. 
Use the PDT macro values for the tag parameter, e.g.: -// TraceLoggingWrite(..., -// TelemetryPrivacyDataTag(PDT_BrowsingHistory | PDT_ProductAndServiceUsage), -// ...); -#define TelemetryPrivacyDataTag(tag) TraceLoggingUInt64((tag), "PartA_PrivTags") -#define PDT_BrowsingHistory 0x0000000000000002u -#define PDT_DeviceConnectivityAndConfiguration 0x0000000000000800u -#define PDT_InkingTypingAndSpeechUtterance 0x0000000000020000u -#define PDT_ProductAndServicePerformance 0x0000000001000000u -#define PDT_ProductAndServiceUsage 0x0000000002000000u -#define PDT_SoftwareSetupAndInventory 0x0000000080000000u - -// Event categories specified via keywords, e.g.: -// TraceLoggingWrite(..., -// TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), -// ...); -#define MICROSOFT_KEYWORD_CRITICAL_DATA 0x0000800000000000 // Bit 47 -#define MICROSOFT_KEYWORD_MEASURES 0x0000400000000000 // Bit 46 -#define MICROSOFT_KEYWORD_TELEMETRY 0x0000200000000000 // Bit 45 -#define MICROSOFT_KEYWORD_RESERVED_44 0x0000100000000000 // Bit 44 (reserved for future assignment) - -// Event categories specified via event tags, e.g.: -// TraceLoggingWrite(..., -// TraceLoggingEventTag(MICROSOFT_EVENTTAG_REALTIME_LATENCY), -// ...); -#define MICROSOFT_EVENTTAG_DROP_USER_IDS 0x00008000 -#define MICROSOFT_EVENTTAG_AGGREGATE 0x00010000 -#define MICROSOFT_EVENTTAG_DROP_PII_EXCEPT_IP 0x00020000 -#define MICROSOFT_EVENTTAG_COSTDEFERRED_LATENCY 0x00040000 -#define MICROSOFT_EVENTTAG_CORE_DATA 0x00080000 -#define MICROSOFT_EVENTTAG_INJECT_XTOKEN 0x00100000 -#define MICROSOFT_EVENTTAG_REALTIME_LATENCY 0x00200000 -#define MICROSOFT_EVENTTAG_NORMAL_LATENCY 0x00400000 -#define MICROSOFT_EVENTTAG_CRITICAL_PERSISTENCE 0x00800000 -#define MICROSOFT_EVENTTAG_NORMAL_PERSISTENCE 0x01000000 -#define MICROSOFT_EVENTTAG_DROP_PII 0x02000000 -#define MICROSOFT_EVENTTAG_HASH_PII 0x04000000 -#define MICROSOFT_EVENTTAG_MARK_PII 0x08000000 - -// Field categories specified via field tags, e.g.: -// TraceLoggingWrite(..., -// TraceLoggingString(szUser, "UserName", "User's name", MICROSOFT_FIELDTAG_HASH_PII), -// ...); -#define MICROSOFT_FIELDTAG_DROP_PII 0x04000000 -#define MICROSOFT_FIELDTAG_HASH_PII 0x08000000 diff --git a/src/ort_include/core/platform/windows/readme.txt b/src/ort_include/core/platform/windows/readme.txt deleted file mode 100644 index f1a436f..0000000 --- a/src/ort_include/core/platform/windows/readme.txt +++ /dev/null @@ -1,2 +0,0 @@ -copied from minkernel/published/internal/telemetry/open_source/TraceLoggingConfig.h -this is the official open source edition for these configuration settings \ No newline at end of file diff --git a/src/ort_include/core/session/onnxruntime_c_api.h b/src/ort_include/core/session/onnxruntime_c_api.h index a187fbd..40fab45 100644 --- a/src/ort_include/core/session/onnxruntime_c_api.h +++ b/src/ort_include/core/session/onnxruntime_c_api.h @@ -31,17 +31,4 @@ #define ORTCHAR_T wchar_t #else #define ORTCHAR_T char -#endif - -/// ORTCHAR_T, ORT_TSTR are reserved specifically for path handling. -/// All other strings are UTF-8 encoded, use char and std::string -#ifndef ORT_TSTR -#ifdef _WIN32 -#define ORT_TSTR(X) L##X -// When X is a macro, L##X is not defined. In this case, we need to use ORT_TSTR_ON_MACRO. 
-#define ORT_TSTR_ON_MACRO(X) L"" X -#else -#define ORT_TSTR(X) X -#define ORT_TSTR_ON_MACRO(X) X -#endif #endif \ No newline at end of file diff --git a/src/ort_include/core/util/thread_utils.h b/src/ort_include/core/util/thread_utils.h index b146c0d..c25f789 100644 --- a/src/ort_include/core/util/thread_utils.h +++ b/src/ort_include/core/util/thread_utils.h @@ -19,7 +19,13 @@ struct OrtThreadPoolParams { bool auto_set_affinity = false; // If it is true, the thread pool will spin a while after the queue became empty. +#if !defined(ORT_CLIENT_PACKAGE_BUILD) bool allow_spinning = true; +#else + // default allow_spinning to false for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. + bool allow_spinning = false; +#endif // It it is non-negative, thread pool will split a task by a decreasing block size // of remaining_of_total_iterations / (num_of_threads * dynamic_block_base_) diff --git a/tests/bench/CMakeLists.txt b/tests/bench/CMakeLists.txt index 8bf535d..75c5836 100644 --- a/tests/bench/CMakeLists.txt +++ b/tests/bench/CMakeLists.txt @@ -1,5 +1,5 @@ include_directories(../../src/lib) -add_executable(onnxruntime_mlas_benchmark bench_computesoftmax.cpp bench_main.cpp bench_q4dq.cpp bench_q4gemm.cpp bench_qgemm.cpp bench_sconv.cpp bench_sgemm.cpp bench_sqnbitgemm.cpp bench_symm_qgemm.cpp bench_util.cpp) +add_executable(onnxruntime_mlas_benchmark bench_computesoftmax.cpp bench_main.cpp bench_q4dq.cpp bench_q4gemm.cpp bench_qgemm.cpp bench_sconv.cpp bench_sgemm.cpp bench_qnbitgemm.cpp bench_symm_qgemm.cpp bench_util.cpp) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE benchmark::benchmark ${ONNXRUNTIME_MLAS_LIBS} ) if(NOT MLAS_NO_ONNXRUNTIME) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE onnxruntime_common) diff --git a/tests/bench/bench_cast.cpp b/tests/bench/bench_cast.cpp new file mode 100644 index 0000000..e323346 --- /dev/null +++ b/tests/bench/bench_cast.cpp @@ -0,0 +1,54 @@ +#include "bench_util.h" +#include "mlasi.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +void BM_ConvertF16ToF32(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, 0, 60000); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + float* dst_start = aligned ? reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + } +} + +void BM_ConvertF32ToF16(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, -30000.0f, 30000.0f); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + unsigned short* dst_start = aligned ? 
reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + } +} + +BENCHMARK(BM_ConvertF16ToF32) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +BENCHMARK(BM_ConvertF32ToF16) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/bench/bench_computesoftmax.cpp b/tests/bench/bench_computesoftmax.cpp index 57ab53c..bf30db4 100644 --- a/tests/bench/bench_computesoftmax.cpp +++ b/tests/bench/bench_computesoftmax.cpp @@ -5,19 +5,7 @@ #include "core/util/thread_utils.h" #include "bench_util.h" -#ifndef BUILD_MLAS_NO_ONNXRUNTIME using onnxruntime::narrow; -#else -using gsl::narrow; -#define ORT_THROW(X) throw std::runtime_error(X) -#define ORT_ENFORCE(condition, ...) \ - do { \ - if (!(condition)) { \ - abort(); \ - } \ - } while (false) -#define ORT_THROW_EX(X) throw X(); -#endif struct RestrictAlignedPtr { float* ptr; // Aligned pointer within the underlying buffer @@ -70,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); } free(ptr.underlying_buffer); diff --git a/tests/bench/bench_fp16_neon_common.cpp b/tests/bench/bench_fp16_neon_common.cpp index 1dccbe4..e323346 100644 --- a/tests/bench/bench_fp16_neon_common.cpp +++ b/tests/bench/bench_fp16_neon_common.cpp @@ -1,5 +1,5 @@ #include "bench_util.h" -#include "core/mlas/lib/mlasi.h" +#include "mlasi.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/bench/bench_hgemm.cpp b/tests/bench/bench_hgemm.cpp new file mode 100644 index 0000000..f42c5e5 --- /dev/null +++ b/tests/bench/bench_hgemm.cpp @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
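The bench_cast.cpp benchmarks above carve either a 16-byte-aligned or a deliberately misaligned destination out of an over-allocated buffer. A standalone sketch of that align-up idiom follows (the helper is illustrative; the cast target is assumed to be uintptr_t).

    #include <cstdint>
    #include <vector>

    // Round a pointer up to the next 16-byte boundary inside an over-allocated buffer.
    inline float* AlignUp16(float* p) {
      const uintptr_t addr = (reinterpret_cast<uintptr_t>(p) + 15) & ~uintptr_t{15};
      return reinterpret_cast<float*>(addr);
    }

    int main() {
      std::vector<float> buffer(1024 + 16);       // extra slack so the aligned region always fits
      float* aligned = AlignUp16(buffer.data());  // 16-byte aligned start
      float* misaligned = aligned + 1;            // offset by one element (4 bytes)
      (void)misaligned;
      return 0;
    }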
+ +#include "mlas.h" +#include "bench_util.h" +#include "core/util/thread_utils.h" + +#include +#include + +static const std::vector hgemm_bench_arg_names = {"M", "N", "K"}; + +void HGEMM(benchmark::State& state, bool transA, bool transB) { + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); + + auto A = RandomVectorUniform(static_cast(M * K), MLAS_FP16(-1.0f), MLAS_FP16(1.0f)); + auto B = RandomVectorUniform(static_cast(N * K), MLAS_FP16(-1.0f), MLAS_FP16(1.0f)); + std::vector C(static_cast(M * N)); + + MLAS_FP16 alpha = MLAS_FP16(1.0f); + MLAS_FP16 beta = MLAS_FP16(0.0f); + OrtThreadPoolParams tpo; + tpo.thread_pool_size = 8; + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + MlasGemm( + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + static_cast(M), + static_cast(N), + static_cast(K), + A.data(), + transA ? M : K, + B.data(), + transB ? K : N, + C.data(), + N, + alpha.val, + beta.val, + tp.get()); + + for (auto _ : state) { + MlasGemm( + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + static_cast(M), + static_cast(N), + static_cast(K), + A.data(), + transA ? M : K, + B.data(), + transB ? K : N, + C.data(), + N, + alpha.val, + beta.val, + tp.get()); + } +} + +static void GemmSizeWithOne(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}}); +} +BENCHMARK_CAPTURE(HGEMM, GEMV_TransB, false, true)->Apply(GemmSizeWithOne)->UseRealTime(); +BENCHMARK_CAPTURE(HGEMM, GEMV_B, false, false)->Apply(GemmSizeWithOne)->UseRealTime(); + +static void GemmSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); +} +BENCHMARK_CAPTURE(HGEMM, NORMAL_TransB, false, true)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK_CAPTURE(HGEMM, NORMAL_B, false, false)->Apply(GemmSizeProducts)->UseRealTime(); + +static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); +} +BENCHMARK_CAPTURE(HGEMM, LLM_TransB, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); +BENCHMARK_CAPTURE(HGEMM, LLM_B, false, false)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/tests/bench/bench_sqnbitgemm.cpp b/tests/bench/bench_qnbitgemm.cpp similarity index 51% rename from tests/bench/bench_sqnbitgemm.cpp rename to tests/bench/bench_qnbitgemm.cpp index 71db7d8..8ad3b59 100644 --- a/tests/bench/bench_sqnbitgemm.cpp +++ b/tests/bench/bench_qnbitgemm.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "benchmark/benchmark.h" @@ -16,22 +17,22 @@ #include "core/util/thread_utils.h" #include "core/platform/env_var_utils.h" -template -void RunSQNBitGemmBenchmark(size_t BlkLen, - size_t M, size_t N, size_t K, - size_t Threads, - bool Symmetric, - bool HasBias, - 
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - benchmark::State& state) { - if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { - state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine."); +template +void RunQNBitGemmBenchmark(size_t BlkLen, + size_t M, size_t N, size_t K, + size_t Threads, + bool Symmetric, + bool HasBias, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + benchmark::State& state) { + if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { + state.SkipWithMessage("QNBitGemm is not available with the given configuration on the current machine."); return; } size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes( - BlkBitWidth, static_cast(BlkLen), /* columnwise */ true, + MlasBlockwiseQuantizedBufferSizes( + static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); @@ -43,40 +44,40 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - const auto A = RandomVectorUniform(M * K, -1.0f, 1.0f); - const auto B = RandomVectorUniform(K * N, -1.0f, 1.0f); + const auto A = RandomVectorUniform(M * K, AType(-1.0f), AType(1.0f)); + const auto B = RandomVectorUniform(K * N, AType(-1.0f), AType(1.0f)); - const auto Bias = HasBias ? RandomVectorUniform(N, -1.0f, 1.0f) : std::vector(); + const auto Bias = HasBias ? RandomVectorUniform(N, AType(-1.0f), AType(1.0f)) : std::vector(); - std::vector C(static_cast(M * N)); + std::vector C(static_cast(M * N)); std::vector QuantBData(QuantBDataSizeInBytes); - std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); bool has_zp_input = !Symmetric; - MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), + MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? 
nullptr : QuantBZeroPoint.data(), B.data(), static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), static_cast(N), tp.get()); std::unique_ptr Workspace; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, !Symmetric, ComputeType); WorkspaceSize > 0) { Workspace = std::make_unique(WorkspaceSize); } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, !Symmetric, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), - QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), - tp.get()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), + tp.get()); } - MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; + MLAS_QNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; if (PackedQuantBData != nullptr) @@ -92,15 +93,15 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, params.ldc = N; // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); } } -template -void SQNBITGEMM(benchmark::State& state) { +template +void QNBITGEMM(benchmark::State& state) { using onnxruntime::narrow; const auto BlkLen = narrow(state.range(0)); @@ -110,46 +111,51 @@ void SQNBITGEMM(benchmark::State& state) { const auto Threads = narrow(state.range(4)); const auto Symmetric = narrow(state.range(5)); const bool HasBias = narrow(state.range(6)); - const auto ComputeType = static_cast(state.range(7)); + const auto ComputeType = static_cast(state.range(7)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { +template +static void QNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"}); b->ArgsProduct({ - {128}, // BlkLen - {1}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{false}, int64_t{true}}, // HasBias - {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType + {128}, // BlkLen + {1, 4096}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {1, 8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{false}, int64_t{true}}, // HasBias + std::is_same_v + ? 
std::vector{int64_t{HQNBIT_CompFp16}} + : std::vector{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType }); } -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); // This test gets benchmark arguments from environment variables. -template -void SQNBITGEMM_ENV(benchmark::State& state) { +template +void QNBITGEMM_ENV(benchmark::State& state) { using onnxruntime::ParseEnvironmentVariableWithDefault; - const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_BLKLEN", 32); - const auto M = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_M", 1); - const auto N = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_N", 4096); - const auto K = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_K", 4096); - const auto Threads = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_THREADS", 1); - const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_SYMMETRIC", true); - const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_HAS_BIAS", false); - const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_COMPUTE_TYPE", - static_cast(CompFp32)); + const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_BLKLEN", 32); + const auto M = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_M", 1); + const auto N = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_N", 4096); + const auto K = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_K", 4096); + const auto Threads = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_THREADS", 1); + const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_SYMMETRIC", true); + const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_HAS_BIAS", false); + const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_COMPUTE_TYPE", + static_cast(SQNBIT_CompFp32)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, - static_cast(ComputeType), - state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, + static_cast(ComputeType), + state); std::ostringstream s; s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen @@ -159,4 +165,4 @@ void SQNBITGEMM_ENV(benchmark::State& state) { state.SetLabel(s.str()); } -BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime(); +BENCHMARK(QNBITGEMM_ENV)->UseRealTime(); diff --git a/tests/bench/bench_rope.cpp b/tests/bench/bench_rope.cpp new file mode 100644 index 0000000..216ee79 --- /dev/null +++ b/tests/bench/bench_rope.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "mlas.h" +#include "benchmark/benchmark.h" +#include "bench_util.h" +#include "core/common/float16.h" + +using namespace onnxruntime; + +template +void RunRoPEBenchmark(size_t rotary_emb_dim, bool interleaved, benchmark::State& state) { + std::vector input(rotary_emb_dim); + size_t table_len = interleaved ? 
rotary_emb_dim / 2 : rotary_emb_dim; + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); + + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = static_cast(i + 1.0f); + } + for (size_t i = 0; i < table_len; ++i) { + // https://arxiv.org/pdf/2104.09864 section 3.4.3 + float theta_i = static_cast(pow(10000, -2.0f * i / rotary_emb_dim)); + sin_data[i] = static_cast(std::sin(theta_i)); + cos_data[i] = static_cast(std::cos(theta_i)); + } + + // warm up run + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + + for (auto _ : state) { + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + } +} + +template +void RoPE(benchmark::State& state) { + using onnxruntime::narrow; + + const auto rotary_emb_dim = narrow(state.range(0)); + const auto interleaved = narrow(state.range(1)); + + RunRoPEBenchmark(rotary_emb_dim, interleaved, state); +} + +template +static void RoPEArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"rotary_emb_dim", "interleaved"}); + + b->ArgsProduct({ + {128, 256, 512, 1024}, // rotary_emb_dim + {int64_t{false}, int64_t{true}}, // interleaved + }); +} + +BENCHMARK(RoPE)->Apply(RoPEArgs)->UseRealTime(); +BENCHMARK(RoPE)->Apply(RoPEArgs)->UseRealTime(); diff --git a/tests/bench/bench_sconv.cpp b/tests/bench/bench_sconv.cpp index 39d1352..dc37980 100644 --- a/tests/bench/bench_sconv.cpp +++ b/tests/bench/bench_sconv.cpp @@ -3,6 +3,7 @@ #include "mlas.h" #include "bench_util.h" +#include "core/util/thread_utils.h" #include #include @@ -138,6 +139,113 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) { } } +static MLAS_THREADPOOL* GetMlasThreadPoolForConvBenchmark(void) { + static auto threadpool = std::make_unique( + &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 4, true); + return threadpool.get(); +} + +void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) { + MLAS_THREADPOOL* tp = GetMlasThreadPoolForConvBenchmark(); + + const int64_t rank = state.range(0); // Rank + const int64_t batch_size = state.range(1); // N + const int64_t groups = state.range(2); // G + const int64_t input_channels_per_group = state.range(3); // Cpg + const int64_t output_channels_per_group = state.range(4); // Fpg + + if (rank <= 0) throw std::invalid_argument("Kernel rank must greater than 0!"); + if (batch_size <= 0) throw std::invalid_argument("Batch size must greater than 0!"); + if (groups <= 0) throw std::invalid_argument("Group count must greater than 0!"); + if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must greater than 0!"); + if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must greater than 0!"); + + size_t arg_position = 5; + const auto input_shape = BenchArgsVector(state, arg_position, rank); + const auto kernel_shape = BenchArgsVector(state, arg_position, rank); + const auto paddings = BenchArgsVector(state, arg_position, rank * 2); + const auto strides = BenchArgsVector(state, arg_position, rank); + const auto dilations = BenchArgsVector(state, arg_position, rank); + + // do not check the size of each vector as they are forced from args. 
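+  // Illustrative note on the shape arithmetic used further below: the output spatial size per
+  // dimension follows the standard convolution relation
+  //   out = (pad_begin + pad_end + in - (1 + dilation * (kernel - 1))) / stride + 1
+  // e.g. with in = 56, kernel = 3, dilation = 1, pad = 1 + 1, stride = 1 (values chosen only
+  // for illustration): out = (1 + 1 + 56 - 3) / 1 + 1 = 56.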
+ if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all input image dim must > 0"); + } + + if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all kernel dim must > 0"); + } + + if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all strides dim must > 0"); + } + + if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all dilations dim must > 0"); + } + + const int64_t GC = groups * input_channels_per_group; + const int64_t GF = groups * output_channels_per_group; + std::vector x_shape = {batch_size, GC}; + x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end()); + std::vector f_shape = {GF, input_channels_per_group}; + f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end()); + + std::vector output_shape((size_t)rank); + for (int64_t i = 0; i < rank; ++i) { + auto km = 1 + dilations[i] * (kernel_shape[i] - 1); + output_shape[i] = (paddings[i] + paddings[i + rank] + input_shape[i] - km) / strides[i] + 1; + } + std::vector y_shape = {batch_size, GF}; + y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end()); + + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + MLAS_CONV_PARAMETERS Parameters; + size_t WorkingBufferSize = 0; + MlasConvPrepare(&Parameters, + static_cast(rank), + static_cast(batch_size), + static_cast(groups), + static_cast(input_channels_per_group), + input_shape.data(), + kernel_shape.data(), + dilations.data(), + paddings.data(), + strides.data(), + output_shape.data(), + static_cast(output_channels_per_group), + &activation, + &WorkingBufferSize, + 0.0f, + tp); + + auto X = RandomVectorUniform(x_shape, -2.0, 2.0); + auto F = RandomVectorUniform(f_shape, -1.0, 1.0); + int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies()); + std::vector Y(static_cast(y_size)); + std::vector working_buffer(WorkingBufferSize); + + // warm up first round. + MlasConv(&Parameters, + X.data(), + F.data(), + nullptr, + working_buffer.data(), + Y.data(), + tp); + + for (auto _ : state) { + MlasConv(&Parameters, + X.data(), + F.data(), + nullptr, + working_buffer.data(), + Y.data(), + tp); + } +} + static void ResNet50(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2)); @@ -221,6 +329,7 @@ static void TeamsModel(benchmark::internal::Benchmark* b) { } BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); +BENCHMARK_CAPTURE(SCONV_NCHW_THREADED, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); static void General_Conv2d(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2)); diff --git a/tests/bench/bench_sgemm.cpp b/tests/bench/bench_sgemm.cpp index a94d33c..422fc6f 100644 --- a/tests/bench/bench_sgemm.cpp +++ b/tests/bench/bench_sgemm.cpp @@ -30,9 +30,12 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); if (pack_b) { - size_t pack_b_size = MlasGemmPackBSize(N, K); + CBLAS_TRANSPOSE transB_enum = trans_b ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE transA_enum = trans_a ? 
CblasTrans : CblasNoTrans; + + size_t pack_b_size = MlasGemmPackBSize(transA_enum, transB_enum, N, K); std::vector B_packed(pack_b_size); - MlasGemmPackB(CblasNoTrans, N, K, B.data(), N, B_packed.data()); + MlasGemmPackB(transA_enum, transB_enum, N, K, B.data(), N, B_packed.data()); MlasGemm( trans_a ? CblasTrans : CblasNoTrans, diff --git a/tests/bench/bench_util.h b/tests/bench/bench_util.h index f96dd5c..e6eda2b 100644 --- a/tests/bench/bench_util.h +++ b/tests/bench/bench_util.h @@ -8,8 +8,12 @@ #include #include +#include "core/common/float16.h" +#include "mlas.h" + template -std::vector RandomVectorUniform( +typename std::enable_if_t, std::vector> +RandomVectorUniform( size_t N, ElementType min_value = std::numeric_limits::lowest(), ElementType max_value = std::numeric_limits::max()) { @@ -26,6 +30,25 @@ std::vector RandomVectorUniform( return r; } +template +typename std::enable_if_t, std::vector> +RandomVectorUniform( + size_t N, + ElementType min_value, + ElementType max_value) { + if (min_value.ToFloat() >= max_value.ToFloat()) { + return std::vector(N, min_value); + } + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(min_value.ToFloat(), max_value.ToFloat()); + + std::vector r(N); + for (size_t i = 0; i < N; i++) { + r[i] = ElementType(distribution(generator)); + } + return r; +} + std::vector RandomVectorUniform(std::vector shape, float min_value, float max_value); std::vector BenchArgsVector(benchmark::State& state, size_t& start, size_t count); diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index a61df1b..764355a 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -1,6 +1,6 @@ find_package(GTest) -add_executable(mlas_unittest test_activation.cpp +onnxruntime_add_executable(mlas_unittest test_activation.cpp test_blkq8.cpp test_blockq4.cpp test_conv2d.cpp @@ -26,6 +26,7 @@ test_sbgemm.cpp test_scaleoutput.cpp test_softmax.cpp test_sqnbitgemm.cpp +test_hqnbitgemm_neon.cpp test_symm_qgemm.cpp test_transpose.cpp) if(MSVC) @@ -46,7 +47,7 @@ if(IOS) XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" ) endif() -target_include_directories(mlas_unittest PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} +target_include_directories(mlas_unittest PRIVATE ${ONNXRUNTIME_ROOT}/lib ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(mlas_unittest PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS}) if(NOT MLAS_NO_ONNXRUNTIME) diff --git a/tests/unittest/matrix_buffer.h b/tests/unittest/matrix_buffer.h new file mode 100644 index 0000000..0af513e --- /dev/null +++ b/tests/unittest/matrix_buffer.h @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
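+// The class below hands out reusable, 64-byte-aligned scratch buffers for test matrices.
+// A minimal usage sketch (type and sizes are illustrative only):
+//
+//   MatrixGuardBuffer<float> buf;
+//   float* zeroed = buf.GetBuffer(1024, /*zero_fill*/ true);
+//   float* filled = buf.GetFilledBuffer(1024, [](float* p, size_t n) {
+//     for (size_t i = 0; i < n; ++i) p[i] = static_cast<float>(i);
+//   });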
+
+#pragma once
+
+#include <cstdlib>     // For malloc, free, abort
+#include <cstddef>     // For size_t
+#include <functional>  // For std::function
+#include <algorithm>   // For std::fill_n
+#include <new>         // For std::bad_alloc (alternative to abort)
+
+template <typename T>
+class MatrixGuardBuffer {
+ public:
+  MatrixGuardBuffer() :
+      buffer_(nullptr),
+      elements_allocated_(0) {
+  }
+
+  ~MatrixGuardBuffer() {
+    ReleaseBuffer();
+  }
+
+  // Disable copy and move semantics for simplicity
+  MatrixGuardBuffer(const MatrixGuardBuffer&) = delete;
+  MatrixGuardBuffer& operator=(const MatrixGuardBuffer&) = delete;
+  MatrixGuardBuffer(MatrixGuardBuffer&&) = delete;
+  MatrixGuardBuffer& operator=(MatrixGuardBuffer&&) = delete;
+
+  T* GetFilledBuffer(size_t elements, const std::function<void(T*, size_t)>& fill_func) {
+    if (elements == 0) {
+      ReleaseBuffer();
+      return nullptr;
+    }
+
+    if (elements > elements_allocated_) {
+      ReleaseBuffer();
+
+      size_t bytes_to_allocate = elements * sizeof(T);
+      if (elements != 0 && bytes_to_allocate / elements != sizeof(T)) {  // Check for overflow before multiplication
+        // Handle overflow, e.g., by aborting or throwing
+        abort();
+      }
+#ifdef _WIN32
+      buffer_ = static_cast<T*>(_aligned_malloc(bytes_to_allocate, 64));
+#else
+      buffer_ = static_cast<T*>(std::aligned_alloc(64, bytes_to_allocate));
+#endif
+
+      if (buffer_ == nullptr) {
+        // Consider `throw std::bad_alloc();` for C++ style error handling.
+        abort();
+      }
+      elements_allocated_ = elements;
+    }
+
+    if (fill_func && buffer_ != nullptr) {
+      fill_func(buffer_, elements);
+    } else if (buffer_ == nullptr && elements > 0) {
+      abort();  // Should not happen if allocation failure aborts
+    }
+
+    return buffer_;
+  }
+
+  T* GetBuffer(size_t elements, bool zero_fill = false) {
+    if (zero_fill) {
+      return GetFilledBuffer(
+          elements,
+          [](T* start, size_t count) {
+            if (start && count > 0) {
+              std::fill_n(start, count, T{});  // Value-initialize
+            }
+          });
+    }
+
+    return GetFilledBuffer(
+        elements,
+        [](T* start, size_t count) {
+          constexpr float offset = -21.f;
+          constexpr float range = 43.f;
+
+          // The following value will be used in most GEMM/CONV tests. Because this value is an integer that is
+          // small enough, all the floating point operations will generate exact values instead of approximate
+          // values.
+          float FillValue = 11.f;
+          T* FillAddress = start;
+          for (size_t i = 0; i < count; i++) {
+            auto itemv = FillValue - offset;
+            *FillAddress++ = (T)(itemv);
+
+            FillValue += 7.f;
+            FillValue = FillValue >= range ?
FillValue - range : FillValue; + } + }); + } + + void ReleaseBuffer() { + if (buffer_ != nullptr) { + #if defined(_WIN32) + _aligned_free(buffer_); + #else + free(buffer_); + #endif + buffer_ = nullptr; + } + elements_allocated_ = 0; + } + +private: + T* buffer_; + size_t elements_allocated_; +}; \ No newline at end of file diff --git a/tests/unittest/test_blockq4.cpp b/tests/unittest/test_blockq4.cpp index f75002f..7e8f128 100644 --- a/tests/unittest/test_blockq4.cpp +++ b/tests/unittest/test_blockq4.cpp @@ -18,192 +18,163 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" +#include "mlasi.h" +template +int GetElem(int v, int idx) { + return (v >> (qbits * idx)) & ((1 << qbits) - 1); +} + +template +int SetElem(int v, int idx, int value) { + v &= ~(((1 << qbits) - 1) << (qbits * idx)); + v |= (value & ((1 << qbits) - 1)) << (qbits * idx); + return v; +} + +template class MlasBlockwiseQdqTest : public MlasTestBase { private: - MatrixGuardBuffer FpBuf; - MatrixGuardBuffer FpBuf2; + std::random_device rd; + std::mt19937 gen{192373}; + std::uniform_int_distribution dist_int{0, (1 << qbits) - 1}; + std::uniform_real_distribution dist_float{-1.f, 1.f}; + MatrixGuardBuffer FpBuf; + MatrixGuardBuffer FpBuf2; + MatrixGuardBuffer FpBuf3; MatrixGuardBuffer InputElements; - MatrixGuardBuffer InputScales; + MatrixGuardBuffer InputScales; MatrixGuardBuffer InputOffsets; - MatrixGuardBuffer OutputElements; - MatrixGuardBuffer OutputScales; - MatrixGuardBuffer OutputOffsets; MatrixGuardBuffer QDQOutputElements; - MatrixGuardBuffer QDQOutputScales; + MatrixGuardBuffer QDQOutputScales; MatrixGuardBuffer QDQOutputOffsets; MatrixGuardBuffer QDQTransposedOutputElements; - MatrixGuardBuffer QDQTransposedOutputScales; + MatrixGuardBuffer QDQTransposedOutputScales; MatrixGuardBuffer QDQTransposedOutputOffsets; + constexpr static float err_ = qbits == 8 ? 1e-2f : qbits == 4 ? 6e-2f + : 2e-1f; + constexpr static float rel_ = qbits == 8 ? 5e-2f : qbits == 4 ? 
2e-1f + : 5e-1f; + + bool FloatEqual(T a, T b, float err = err_, float rel = rel_) { + float va = static_cast(a); + float vb = static_cast(b); + return std::abs(va - vb) < err + std::abs(va) * rel; + } void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { - float* dequant_buf = FpBuf.GetBuffer(rows * columns, true); - float* transposed = FpBuf2.GetBuffer(rows * columns, true); + constexpr int packSize = 8 / qbits; + T* input = FpBuf.GetFilledBuffer(rows * columns, [this](T* start, size_t size) { + for (size_t i = 0; i < size; i++) { + start[i] = T(this->dist_float(this->gen)); + } + }); + T* dequant = FpBuf2.GetBuffer(rows * columns, true); + T* transposed = FpBuf3.GetBuffer(rows * columns, true); size_t scale_size = (rows + block_size - 1) / block_size * columns; - size_t zp_size = (scale_size + 1) / 2; - + size_t zp_size = (scale_size + packSize - 1) / packSize; MLAS_THREADPOOL* threadpool_ptr = GetMlasThreadPool(); int meta_rows; int meta_cols; - MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); int q_rows; int q_cols; - MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); + MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns, - q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); - - uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); - uint8_t* qdq_weights = QDQOutputElements.GetBuffer((rows * columns + 1) / 2, true); - uint8_t* qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); - - int v = 7; - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - elements[idx] = (v1 << 4) | v0; - } + MlasBlockwiseQuantizedBufferSizes(block_size, columnwise, rows, columns, + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); // after quantize + uint8_t* qdq_weights; + uint8_t* qdq_weights_T; + if constexpr (qbits == 4) { + qdq_weights = QDQOutputElements.GetBuffer((rows * columns + packSize - 1) / packSize, true); + qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); } - float* scales = InputScales.GetBuffer(q_scale_size); - float* qdq_scales = QDQOutputScales.GetBuffer(scale_size); - float* qdq_scales_T = QDQTransposedOutputScales.GetBuffer(q_scale_size); + T* scales = InputScales.GetBuffer(q_scale_size, true); uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true); - uint8_t* qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); - uint8_t* qdq_zp_T = symmetric ? 
nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); - if (zp) { - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < meta_rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - zp[idx] = (v1 << 4) | v0; - } - } + T* qdq_scales; + T* qdq_scales_T; + uint8_t* qdq_zp; + uint8_t* qdq_zp_T; + if constexpr (qbits == 4) { + qdq_scales = QDQOutputScales.GetBuffer(scale_size, true); + qdq_scales_T = QDQTransposedOutputScales.GetBuffer(q_scale_size, true); + qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); + qdq_zp_T = symmetric ? nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); } - MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, - columnwise, rows, columns, threadpool_ptr); - - MlasTranspose(dequant_buf, transposed, columns, rows); - - uint8_t* o_elements = OutputElements.GetBuffer(q_rows * q_cols, true); - float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); - uint8_t* o_zp = symmetric ? nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); - - MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, + MlasQuantizeBlockwise(elements, scales, zp, input, block_size, columnwise, rows, columns, columns, threadpool_ptr); - if (columnwise) { - bool signed_quant = MlasQDQQuantizeBlockwise( - transposed, qdq_scales, qdq_zp, qdq_weights, - true, rows, columns, block_size, threadpool_ptr); + MlasDequantizeBlockwise(dequant, elements, scales, zp, block_size, + columnwise, rows, columns, threadpool_ptr); - ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; + MlasTranspose(dequant, transposed, columns, rows, threadpool_ptr); - if (symmetric) { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + if constexpr (qbits == 4) { + if (columnwise) { + bool signed_quant = MlasQDQQuantizeBlockwise( + input, qdq_scales, qdq_zp, qdq_weights, true, rows, columns, block_size, threadpool_ptr); - } else { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); + ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; + + if (symmetric) { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + + } else { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + } } } - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) + for (int r = 0; r < rows; r++) { + for (int c = 0; c < columns; c++) { + int idx = r * columns + c; + ASSERT_TRUE(FloatEqual(input[idx], transposed[idx])) + << " input: " << input[idx] << ", transposed: " << transposed[idx] << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns << "] 
block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) - << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } + } + } - if (r + 1 < rows) { - ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + if (columnwise && qbits == 4) { + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += packSize) { + int idx = c * q_rows + r / packSize; + for (int l = 0; l < packSize && l + r < rows; ++l) { + ASSERT_EQ(GetElem(qdq_weights_T[idx], l), GetElem(elements[idx], l)) + << ", qdq index=[" << r + l << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; } } } - } - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r++) { - int idx = c * meta_rows + r; - ASSERT_EQ(o_scales[idx], scales[idx]) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - - if (columnwise) { - ASSERT_EQ(qdq_scales_T[idx], scales[idx]) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r++) { + int idx = c * meta_rows + r; + ASSERT_TRUE(FloatEqual(qdq_scales_T[idx], scales[idx])) + << ", qdq index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; } } - } - if (symmetric) return; - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - if (r + 1 < meta_rows) { - ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + if (symmetric) return; + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += packSize) { + int idx = c * ((meta_rows + packSize - 1) / packSize) + r / packSize; + for (int l = 0; l < packSize && r + l < meta_rows; ++l) { + ASSERT_EQ(GetElem(qdq_zp_T[idx], l), GetElem(zp[idx], l)) + << ", qdq index=" << r + l << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; } } @@ -213,7 
+184,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { public: static const char* GetTestSuiteName() { - static const std::string suite_name("BlockQ4"); + static const std::string suite_name("BlockQ" + std::to_string(qbits)); return suite_name.c_str(); } @@ -263,7 +234,9 @@ class MlasBlockwiseQdqTest : public MlasTestBase { static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { - count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } return count; }); diff --git a/tests/unittest/test_dequantizelinear.cpp b/tests/unittest/test_dequantizelinear.cpp new file mode 100644 index 0000000..b994981 --- /dev/null +++ b/tests/unittest/test_dequantizelinear.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" + +template +class MlasDequantizeLinearTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + void GenerateReference(const QuantInt* Input, float* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } + } + + void Test(size_t N) { + QuantInt* Input = BufferInput.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + + std::uniform_real_distribution min_gen(-10.f, -10e-3f); + float MinimumValue = min_gen(generator); + + std::uniform_real_distribution max_gen(10e-3f, 10.f); + float MaximumValue = max_gen(generator); + + float Scale = (MaximumValue - MinimumValue) / 512.f; + + std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), + std::numeric_limits::max()); + QuantInt ZeroPoint = static_cast(zp_distribution(generator)); + + for (size_t n = 0; n < N; n++) { + Input[n] = static_cast(zp_distribution(generator)); + } + + GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); + MlasDequantizeLinear(Input, Output, N, Scale, ZeroPoint); + + for (size_t n = 0; n < N; n++) { + ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; + } + } + + public: + static const char* GetTestSuiteName() { + if constexpr (std::is_same_v) { + return "DequantizeLinearS8"; + } else { + return "DequantizeLinearU8"; + } + } + + void ExecuteShort(void) override { + for (size_t n = 1; n <= 512; n++) { + Test(n); + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + } + return count; +}); diff --git a/tests/unittest/test_dynamic_qgemm.cpp b/tests/unittest/test_dynamic_qgemm.cpp new file mode 100644 index 0000000..2a5d72b --- /dev/null +++ b/tests/unittest/test_dynamic_qgemm.cpp @@ -0,0 +1,172 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +// Currently this test only applies to KleidiAI Guard against it 
running in any other situation +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + +#include "test_util.h" +#include "mlasi.h" // for MLAS_CPUIDINFO + +class MlasDynamicQgemmTest { + private: + MatrixGuardBuffer buffer_a; + MatrixGuardBuffer buffer_bf; + MatrixGuardBuffer buffer_bq; + MatrixGuardBuffer buffer_c; + MatrixGuardBuffer buffer_c_ref; + + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) { + GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. Skipping test."; + } + + // Setup buffers for holding various data + + float* A = buffer_a.GetBuffer(M * K * BatchSize); + // Buffer for holding floating point version of weight matrix + float* Bf = buffer_bf.GetBuffer(K * N * BatchSize); + // Buffer for holding quantized version of weight matrix + int8_t* Bq = buffer_bq.GetBuffer(K * N * BatchSize); + float* C = buffer_c.GetBuffer(M * N * BatchSize); + float* CRef = buffer_c_ref.GetBuffer(M * N * BatchSize); + + // Initialize A and Bf + for (size_t i = 0; i < M * K * BatchSize; ++i) + A[i] = static_cast((rand() % 255 - 128) / 16.0f); + for (size_t i = 0; i < K * N * BatchSize; ++i) + Bf[i] = static_cast((rand() % 255 - 128) / 16.0f); + + // Quantize Bf → Bq and compute per-column scale and bias per batch + std::vector> b_scale_batches(BatchSize, std::vector(N)); + std::vector> b_bias_batches(BatchSize, std::vector(N, 0.0f)); + + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t n = 0; n < N; ++n) { + float min_val = Bf[b * K * N + n]; + float max_val = min_val; + for (size_t k = 1; k < K; ++k) { + float v = Bf[b * K * N + k * N + n]; + min_val = std::min(min_val, v); + max_val = std::max(max_val, v); + } + float scale = (max_val - min_val) / 255.0f; + if (scale < 1e-8f) scale = 1.0f; + b_scale_batches[b][n] = scale; + + for (size_t k = 0; k < K; ++k) { + float v = Bf[b * K * N + k * N + n]; + int q = static_cast(std::round(v / scale)); + Bq[b * K * N + k * N + n] = static_cast(std::clamp(q, -128, 127)); + } + } + } + + // Prepare kernel parameters + MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS shape{M, N, K}; + std::vector packed_b_storage(BatchSize * MlasDynamicQgemmPackBSize(N, K)); + std::vector params(BatchSize); + + for (size_t b = 0; b < BatchSize; ++b) { + params[b].A = A + b * M * K; + params[b].lda = K; + params[b].C = C + b * M * N; + params[b].ldc = N; + // Pack b matrix using MlasDynamicQgemmPackBSize & MlasDynamicQgemmPackB + void* packed_b = packed_b_storage.data() + b * MlasDynamicQgemmPackBSize(N, K); + MlasDynamicQgemmPackB(N, K, + Bq + b * K * N, + b_scale_batches[b].data(), + b_bias_batches[b].data(), + packed_b); + params[b].PackedB = packed_b; + } + + // call MlasDynamicQGemmBatch Function + MlasDynamicQGemmBatch(shape, params.data(), BatchSize, nullptr); + + // Compute reference result + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[b * M * K + m * K + k]; + float bval = static_cast(Bq[b * K * N + k * N + n]) * b_scale_batches[b][n]; + sum += a * bval; + } + CRef[b * M * N + m * N + n] = sum; + } + } + } + + // Validate results + for (size_t i = 0; i < M * N * BatchSize; ++i) { + float abs_c_ref = std::abs(CRef[i]); + float dynamic_rel_tol = (K <= 4) ? 
0.05f : 0.03f; + float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); + float abs_tol = 3.0f; + float allowed = std::max(rel_tol, abs_tol); + float diff = std::abs(C[i] - CRef[i]); + ASSERT_LE(diff, allowed); + } + } + + static const char* GetTestSuiteName() { + return "DynamicQgemm"; + } +}; + +class DynamicQgemmExecuteTest : public MlasTestFixture { + public: + DynamicQgemmExecuteTest(size_t M, size_t N, size_t K, size_t BatchSize) + : M_(M), N_(N), K_(K), BatchSize_(BatchSize) {} + + void TestBody() override { + this->mlas_tester->Test(M_, N_, K_, BatchSize_); + } + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t BatchSize) { + std::stringstream ss; + ss << "M" << M << "_N" << N << "_K" << K << "_B" << BatchSize; + + std::string test_name = ss.str(); + + testing::RegisterTest( + MlasDynamicQgemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture* { + return new DynamicQgemmExecuteTest(M, N, K, BatchSize); + }); + + return 1; + } + + static size_t RegisterAll(bool is_short_execute) { + const std::vector batch_size = is_short_execute ? std::vector{1UL, 2UL, 4UL} + : std::vector{1UL, 2UL, 4UL, 8UL, 16UL, 32UL, 64UL}; + size_t count = 0; + const size_t sizes[] = {1, 4, 8, 16, 32, 64}; + for (size_t M : sizes) + for (size_t N : sizes) + for (size_t K : sizes) + for (size_t B : batch_size) + count += RegisterSingleTest(M, N, K, B); + return count; + } + + private: + size_t M_, N_, K_, BatchSize_; +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +}); +#endif diff --git a/tests/unittest/test_eltwise.cpp b/tests/unittest/test_eltwise.cpp new file mode 100644 index 0000000..720ff37 --- /dev/null +++ b/tests/unittest/test_eltwise.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
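+// The tests below exercise MlasEltwiseAdd, which computes Output[i] = InputLeft[i] + InputRight[i]
+// element-wise. A minimal sketch of the check performed (the tolerances shown are the fp32 ones
+// used in this file):
+//
+//   float diff = std::fabs(Output[i] - (InputLeft[i] + InputRight[i]));
+//   bool ok = diff <= 1e-6f || diff <= std::fabs(InputLeft[i] + InputRight[i]) * 1e-6f;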
+ +#include "test_util.h" +#include "mlasi.h" +#include "eltwise.h" + +class MlasEltwiseAddTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputLeft; + MatrixGuardBuffer BufferInputRight; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferInputLeftFp16; + MatrixGuardBuffer BufferInputRightFp16; + MatrixGuardBuffer BufferOutputFp16; + + void Test(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { + float* InputLeft = BufferInputLeft.GetBuffer(N); + float* InputRight = BufferInputRight.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = distribution(generator); + InputRight[n] = ScalarValue.value_or(distribution(generator)); + } + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = InputLeft[n] + InputRight[n]; + } + + MlasEltwiseAdd(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 1e-6f; + constexpr float RelativeTolerance = 1e-6f; + + for (size_t n = 0; n < N; n++) { + float diff = std::fabs(Output[n] - OutputReference[n]); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance) + << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n]; + } + } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + void TestFp16(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { + MLAS_FP16* InputLeft = BufferInputLeftFp16.GetBuffer(N); + MLAS_FP16* InputRight = BufferInputRightFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = MLAS_FP16(distribution(generator)); + InputRight[n] = MLAS_FP16(ScalarValue.value_or(distribution(generator))); + } + + MlasEltwiseAdd(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 5e-4f; + constexpr float RelativeTolerance = 1e-3f; + + for (size_t n = 0; n < N; n++) { + float inLeft = InputLeft[n].ToFloat(); + float inRight = InputRight[n].ToFloat(); + float ref = inLeft + inRight; + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << inLeft << ", " << inRight << ", got: " << out << ", expecting: " << ref + << ", r-diff: " << diff / std::fabs(ref); + } + } + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Eltwise_Add"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { + Test(n, -10.f, 10.f); + Test(n, -10.f, 10.f, -5000.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, -17.f, 11.f); + TestFp16(n, -17.f, 11.f, -5000.f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if 
(is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/tests/unittest/test_fgemm.h b/tests/unittest/test_fgemm.h index 2bd0941..6d70151 100644 --- a/tests/unittest/test_fgemm.h +++ b/tests/unittest/test_fgemm.h @@ -56,7 +56,7 @@ class FgemmPackedContext { } }; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) template <> class FgemmPackedContext { public: @@ -112,11 +112,11 @@ class FgemmPackedContext { float* C, size_t ldc, MLAS_THREADPOOL* threadpool) { - size_t PackedBSize = MlasGemmPackBSize(N, K); + size_t PackedBSize = MlasGemmPackBSize(TransA, TransB, N, K); void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true); std::vector data(BatchSize); for (size_t i = 0; i < BatchSize; i++) { - MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); + MlasGemmPackB(TransA, TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); data[i].BIsPacked = true; data[i].A = A + M * K * i; data[i].lda = lda; diff --git a/tests/unittest/test_fgemm_fixture.h b/tests/unittest/test_fgemm_fixture.h index 53b3eda..c832ca6 100644 --- a/tests/unittest/test_fgemm_fixture.h +++ b/tests/unittest/test_fgemm_fixture.h @@ -70,6 +70,7 @@ class FgemmShortExecuteTest : public MlasTestFixture #include "test_fp16.h" diff --git a/tests/unittest/test_hgemm_neon.cpp b/tests/unittest/test_hgemm_neon.cpp new file mode 100644 index 0000000..19d41fa --- /dev/null +++ b/tests/unittest/test_hgemm_neon.cpp @@ -0,0 +1,683 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_hgemm_neon.cpp + +Abstract: + + Tests for MLAS fp16 GEMM on ARM CPU. 
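+    The GEMM kernels under test compute C = alpha * A * op(B) + beta * C in fp16, where op(B)
+    is B, transposed B, or B pre-packed into column blocks of 32/16/8 elements; the packing
+    tests check that block layout directly. Each GEMM test compares the kernel result against
+    a plain fp32-accumulating reference loop, roughly (transposed-B case, leading dimensions
+    illustrative):
+
+        C[m * ldc + n] = alpha * sum_k(A[m * lda + k] * B[n * ldb + k]) + beta * C[m * ldc + n]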
+ +--*/ + +#include +#include + +#include "test/mlas/unittest/test_util.h" +#include "mlasi.h" +#include "halfgemm.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonHGemmPackBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer input_, ref_, packed_; + + template + MLAS_FORCEINLINE void PackB_TransposedB(const MLAS_FP16* src, MLAS_FP16* dst) { + size_t i = 0; + for (; i + 32 <= N; i += 32) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 32; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + } + if (i + 16 <= N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 16; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + i += 16; + } + if (i + 8 <= N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 8; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + i += 8; + } + if (i < N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < N - i; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + dst += 8 - (N - i); + } + } + } + + template + MLAS_FORCEINLINE void PackB_B(const MLAS_FP16* src, MLAS_FP16* dst) { + size_t i = 0; + for (; i + 32 <= N; i += 32) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 32; ++k) { + *dst = src[(i + k) + j * N]; + ++dst; + } + } + } + for (; i + 16 <= N; i += 16) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 16; ++k) { + *dst = src[(i + k) + j * N]; + ++dst; + } + } + } + if (i + 8 <= N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 8; ++k) { + *dst = src[(i + k) + j * N]; + ++dst; + } + } + i += 8; + } + if (i < N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < N - i; ++k) { + *dst = src[(i + k) + j * N]; + ++dst; + } + dst += 8 - (N - i); + } + } + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* packed, const MLAS_FP16* ref) { + size_t j = 0; + for (; j + 31 < N; j += 32) { + for (size_t i = 0; i < 32 * K; ++i) { + ASSERT_EQ(packed[j * K + i].val, ref[j * K + i].val) + << " seed " << seed_ << " K " << i / 32 << " N " << j + i % 32; + } + } + for (; j + 15 < N; j += 16) { + for (size_t i = 0; i < 16 * K; ++i) { + ASSERT_EQ(packed[j * K + i].val, ref[j * K + i].val) + << " seed " << seed_ << " K " << i / 16 << " N " << j + i % 16; + } + } + if (j + 7 < N) { + for (size_t i = 0; i < 8 * K; ++i) { + ASSERT_EQ(packed[j * K + i].val, ref[j * K + i].val) + << " seed " << seed_ << " K " << i / 8 << " N " << j + i % 8; + } + j += 8; + } + if (j < N) { + for (size_t i = 0; i < K; ++i) { + for (size_t k = 0; k < N - j; ++k) { + ASSERT_EQ(packed[j * K + i * 8 + k].val, ref[j * K + i * 8 + k].val) + << " seed " << seed_ << " K " << i << " N " << j + k; + } + } + } + } + + template + void TestPackB_TransposedB() { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(N * K, InitializeBuffer); + auto* packed = packed_.GetBuffer(K * ((N + 7) & ~7), true); + auto* ref = ref_.GetBuffer(K * ((N + 7) & ~7), true); + hgemm_neon::HPackB_TransposedB_Kernel(input, packed, N, K, K); + PackB_TransposedB(input, ref); + Check(packed, ref); + } + + template + void TestPackB_B() { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = 
MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(N * K, InitializeBuffer); + auto* packed = packed_.GetBuffer(K * ((N + 7) & ~7), true); + auto* ref = ref_.GetBuffer(K * ((N + 7) & ~7), true); + hgemm_neon::HPackB_B_Kernel(input, packed, N, K, N); + PackB_B(input, ref); + Check(packed, ref); + } + + public: + MlasNeonHGemmPackBTest() + : seed_(rd_()), gen_(seed_), distrib_(-100.f, 100.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmPackB"; + } + + void ExecuteShort(void) override { + TestPackB_TransposedB<1, 1>(); + TestPackB_TransposedB<1, 15>(); + TestPackB_TransposedB<1, 31>(); + TestPackB_TransposedB<8, 1>(); + TestPackB_TransposedB<8, 16>(); + TestPackB_TransposedB<9, 31>(); + TestPackB_TransposedB<9, 33>(); + TestPackB_TransposedB<31, 33>(); + TestPackB_TransposedB<33, 67>(); + TestPackB_TransposedB<63, 96>(); + TestPackB_TransposedB<271, 263>(); + TestPackB_B<1, 1>(); + TestPackB_B<1, 15>(); + TestPackB_B<1, 31>(); + TestPackB_B<8, 1>(); + TestPackB_B<8, 16>(); + TestPackB_B<9, 31>(); + TestPackB_B<9, 33>(); + TestPackB_B<31, 31>(); + TestPackB_B<63, 33>(); + TestPackB_B<33, 31>(); + TestPackB_B<33, 33>(); + TestPackB_B<65, 67>(); + TestPackB_B<65, 96>(); + TestPackB_B<271, 263>(); + } +}; + +class MlasNeonHGemmTransposedBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer A_, B_, ref_, C_; + + template + MLAS_FORCEINLINE void HGemm( + const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta, + size_t lda, size_t ldb, size_t ldc) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * ldb + k].ToFloat()); + } + C[m * ldc + n] = MLAS_FP16(accu * alphaf + C[m * ldc + n].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref, size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t k = i * ldc + j; + ASSERT_TRUE(FloatEqual(C[k], ref[k], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " N " << N << " ldc " << ldc + << " value " << C[k] << " ref " << ref[k]; + } + } + } + + template + MLAS_FORCEINLINE void Copy(const MLAS_FP16* C, MLAS_FP16* ref, size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t k = i * ldc + j; + ref[k] = C[k]; + } + } + } + + template + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const size_t lda = ((K + 7) & ~7); + const size_t ldb = ((K + 7) & ~7) + 8; + const size_t ldc = ((N + 7) & ~7); + const auto* A = A_.GetFilledBuffer(M * lda, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(ldb * N, InitializeBuffer); + auto* C = C_.GetFilledBuffer(M * ldc, InitializeBuffer); + auto* ref = ref_.GetBuffer(M * ldc, true); + Copy(C, ref, ldc); + hgemm_neon::HGemm_TransposedB_Kernel(A, B, C, M, N, K, lda, ldb, 
ldc, alpha.val, beta.val); + HGemm(A, B, ref, alpha, beta, lda, ldb, ldc); + Check(C, ref, ldc); + } + + public: + MlasNeonHGemmTransposedBTest() + : seed_(1928375), gen_(seed_), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmTransposedB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + } +}; + +class MlasNeonHGemmBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer A_, B_, ref_, C_; + + template + MLAS_FORCEINLINE void HGemm( + const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta, + size_t lda, size_t ldb, size_t ldc) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n + k * ldb].ToFloat()); + } + C[m * ldc + n] = MLAS_FP16(accu * alphaf + C[m * ldc + n].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref, size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t idx = i * ldc + j; + ASSERT_TRUE(FloatEqual(C[idx], ref[idx], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " N " << N << " ldc " << ldc + << " value " << C[idx] << " ref " << ref[idx]; + } + } + } + + template + MLAS_FORCEINLINE void Copy(const MLAS_FP16* C, MLAS_FP16* ref, size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t idx = i * ldc + j; + ref[idx] = C[idx]; + } + } + } + + template + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const size_t lda = ((K + 7) & ~7); + const size_t ldb = ((N + 7) & ~7) + 8; + const size_t ldc = ((N + 7) & ~7); + const auto* A = A_.GetFilledBuffer(M * lda, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * ldb, InitializeBuffer); + auto* C = C_.GetFilledBuffer(M * ldc, InitializeBuffer); + auto* ref = ref_.GetBuffer(M * ldc, true); + Copy(C, ref, ldc); + hgemm_neon::HGemm_B_Kernel(A, B, C, M, N, K, lda, ldb, ldc, alpha.val, beta.val); + HGemm(A, B, ref, alpha, beta, lda, ldb, ldc); + Check(C, ref, ldc); + } + + public: + MlasNeonHGemmBTest() + : seed_(172387), gen_(seed_), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmB"; + } + + void ExecuteShort(void) override { + 
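+    // Each case below is TestHGemm<M, N, K>(alpha, beta); the shapes range from single elements to
+    // sizes just below and above multiples of 8, 16, 32, and 64 to hit the kernel's edge handling.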
TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(1.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.f), MLAS_FP16(0.f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(1.f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.f), MLAS_FP16(1.f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(1.f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.f), MLAS_FP16(0.f)); + TestHGemm<1, 78, 263>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.f), MLAS_FP16(1.0f)); + TestHGemm<2, 65, 65>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 63, 63>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 65, 63>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 63, 65>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + } +}; + +class MlasNeonHGemmPackedBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer A_, B_, ref_, C_; + + template + MLAS_FORCEINLINE void HGemm( + const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta, + size_t lda, size_t ldc) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + size_t n = 0; + for (; n + 32 <= N; n += 32) { + for (size_t i = 0; i < 32; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * K + k * 32 + i].ToFloat()); + } + C[m * ldc + n + i] = MLAS_FP16(accu * alphaf + C[m * ldc + n + i].ToFloat() * betaf); + } + } + } + for (; n + 16 <= N; n += 16) { + for (size_t i = 0; i < 16; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * K + k * 16 + i].ToFloat()); + } + C[m * ldc + n + i] = MLAS_FP16(accu * alphaf + C[m * ldc + n + i].ToFloat() * betaf); + } + } + } + if (n + 8 <= N) { + for (size_t i = 0; i < 8; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * K + k * 8 + i].ToFloat()); + } + C[m * ldc + n + i] = MLAS_FP16(accu * alphaf + C[m * ldc + n + i].ToFloat() * betaf); + } + } + n += 8; + } + if (n < N) { + for (size_t i = 0; i < N - n; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * K + k * 8 + i].ToFloat()); + } + C[m * ldc + n + i] = MLAS_FP16(accu * alphaf + C[m * ldc + n + i].ToFloat() * betaf); + } + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + 
MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref, const size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t k = i * ldc + j; + ASSERT_TRUE(FloatEqual(C[k], ref[k], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " K " << K << " N " << N + << " value " << C[k] << " ref " << ref[k]; + } + } + } + + template + MLAS_FORCEINLINE void Copy(const MLAS_FP16* C, MLAS_FP16* ref, const size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t k = i * ldc + j; + ref[k] = C[k]; + } + } + } + + template + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const size_t lda = ((K + 7) & ~7) + 8; + const size_t ldc = ((N + 7) & ~7); + const auto* A = A_.GetFilledBuffer(M * lda, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * ((N + 7) & ~7), InitializeBuffer); + auto* C = C_.GetFilledBuffer(M * ldc, InitializeBuffer); + auto* ref = ref_.GetBuffer(M * ldc, true); + Copy(C, ref, ldc); + hgemm_neon::HGemm_PackedB_Kernel(A, B, C, M, N, K, lda, ldc, alpha.val, beta.val); + HGemm(A, B, ref, alpha, beta, lda, ldc); + Check(C, ref, ldc); + } + + public: + MlasNeonHGemmPackedBTest() + : seed_(1928372), gen_(), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmPackedB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + } +}; + +class MlasNeonHGemmTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer A_, B_, ref_, C_; + + template + MLAS_FORCEINLINE void HGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta, + size_t lda, size_t ldb, size_t ldc) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[transA ? k * lda + i : i * lda + k].ToFloat()) * (B[transB ? 
j * ldb + k : k * ldb + j].ToFloat()); + } + C[i * ldc + j] = MLAS_FP16(accu * alphaf + C[i * ldc + j].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref, const size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + ASSERT_TRUE(FloatEqual(C[i * ldc + j], ref[i * ldc + j], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " K " << K << " N " << N + << " value " << C[i * ldc + j] << " ref " << ref[i * ldc + j]; + } + } + } + + template + MLAS_FORCEINLINE void Copy(const MLAS_FP16* C, MLAS_FP16* ref, const size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + ref[i * ldc + j] = C[i * ldc + j]; + } + } + } + + template + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const size_t lda = transA ? (M + 15) & (~15) : (K + 15) & (~15); + const size_t ldb = transB ? (K + 7) & (~7) : (N + 15) & (~15); + const size_t ldc = (N + 7) & (~7); + const auto* A = A_.GetFilledBuffer(transA ? lda * K : M * lda, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(transB ? ldb * N : K * ldb, InitializeBuffer); + auto* C = C_.GetFilledBuffer(M * ldc, InitializeBuffer); + auto* ref = ref_.GetBuffer(M * ldc, true); + Copy(C, ref, ldc); + MlasGemm(transA ? CblasTrans : CblasNoTrans, transB ? CblasTrans : CblasNoTrans, + M, N, K, A, lda, B, ldb, C, ldc, alpha.val, beta.val, nullptr); + HGemm(A, B, ref, alpha, beta, lda, ldb, ldc); + Check(C, ref, ldc); + } + + public: + MlasNeonHGemmTest() + : seed_(192837), gen_(seed_), distrib_(-0.25f, 0.25f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemm"; + } + + // TODO(fajin): test beta + void ExecuteShort(void) override { + TestHGemm<2, 1, 1, false, true>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 128, 512, false, true>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 128, 513, false, true>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 128, 511, false, true>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 129, 512, false, true>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 127, 512, false, true>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 513, 1023, false, true>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 511, 1025, false, true>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<127, 513, 1023, false, true>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<129, 511, 1025, false, true>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1, false, false>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 128, 512, false, false>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 128, 513, false, false>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 128, 511, false, false>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 129, 512, false, false>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 127, 512, false, false>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 513, 1023, false, false>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 511, 1025, false, false>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<127, 513, 1023, false, false>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<129, 511, 1025, false, 
false>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<129, 513, 1025, false, false>(MLAS_FP16(0.5f), MLAS_FP16(0.5f)); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/unittest/test_hqnbitgemm_neon.cpp b/tests/unittest/test_hqnbitgemm_neon.cpp new file mode 100644 index 0000000..946eb67 --- /dev/null +++ b/tests/unittest/test_hqnbitgemm_neon.cpp @@ -0,0 +1,501 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_hqnbitgemm_neon.cpp + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. + +--*/ + +#include +#include + +#include "test_util.h" +#include "mlasi.h" +#include "qnbitgemm.h" +#include "mlas_qnbit.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonFp16CastTest : public MlasTestBase { + private: + MatrixGuardBuffer fp32Buffer_; + MatrixGuardBuffer fp16Buffer_; + + template + void TestFp16ToFp32() { + const auto* src = fp16Buffer_.GetFilledBuffer(count, [](unsigned short* start, size_t size) { + for (size_t i = 0; i < size; i++) { + start[i] = static_cast(i); + } + }); + auto* dest = fp32Buffer_.GetBuffer(count, true); + + MlasCastF16ToF32KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + if ((src[i] & 0x1c00) == 0x1c00) continue; // skip inf and nan + ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat()); + } + } + + template + void TestFp32ToFp16() { + const auto* src = fp32Buffer_.GetFilledBuffer(count, [](float* p, size_t size) { + for (size_t i = 0; i < size; i++) { + p[i] = static_cast(i) + 0.125f; + } + }); + auto* dest = fp16Buffer_.GetBuffer(count, true); + + MlasCastF32ToF16KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val); + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16Cast"; + } + + void ExecuteShort(void) override { + TestFp16ToFp32<(1 << 16)>(); + TestFp16ToFp32<1>(); + TestFp16ToFp32<4>(); + TestFp16ToFp32<7>(); + TestFp32ToFp16<(1 << 16)>(); + TestFp32ToFp16<3>(); + TestFp32ToFp16<4>(); + TestFp32ToFp16<6>(); + } +}; + +class MlasNeonFp16PrepackTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + MatrixGuardBuffer input_, ref_, packed_; + + template + MLAS_FORCEINLINE void Transpose8x8(const uint8_t* src, size_t n, size_t k, uint8_t* dst) { + for (size_t c = 0; c < 8; c++) { + for (size_t r = 0; r < 8; r++) { + size_t i = (n + c) * Ldb + r + k; + size_t j = n * Ldb + (r + k) * 8 + c; + dst[j] = src[i]; + } + } + } + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? 
(v >> 4) : (v & 0x0f); + } + + MLAS_FORCEINLINE + void PrepackSlice(const uint8_t* src, size_t j, uint8_t* dst) { + for (size_t i = 0; i < 8; i++) { + uint8_t v0 = GetInt4(src[j + (i >> 1)], i); + uint8_t v1 = GetInt4(src[j + ((8 + i) >> 1)], i + 8); + dst[j + i] = v0 | (v1 << 4); + } + } + + template + MLAS_FORCEINLINE void Prepack(const uint8_t* src, uint8_t* dst) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t k = 0; k < Ldb; k += 8) { + Transpose8x8(src, n, k, dst); + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < Ldb; k += 8) { + PrepackSlice(src, n * Ldb + k, dst); + } + } + } + + template + MLAS_FORCEINLINE void Check(const uint8_t* packed, const uint8_t* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; i += 2) { + for (size_t j = 0; j < 8; ++j) { + ASSERT_EQ(packed[n * Ldb + (i >> 1) * 8 + j], ref[n * Ldb + (i >> 1) * 8 + j]) + << " seed " << seed_ + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; i += 2) { + ASSERT_EQ(packed[n * Ldb + (i >> 1)], ref[n * Ldb + (i >> 1)]) + << " seed " << seed_ + << " n " << n << " i " << i; + } + } + } + + template + void TestPrepack() { + constexpr size_t Bits = 4; + constexpr size_t Ldb = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t BufferSize = N * Ldb; + auto InitializeBuffer = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BufferSize, InitializeBuffer); + auto* packed = packed_.GetBuffer(BufferSize, true); + auto* ref = ref_.GetBuffer(BufferSize, true); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::HQNBIT_CompFp16, input, packed, + nullptr, false, nullptr, nullptr); + Prepack(input, ref); + Check(packed, ref); + } + + public: + MlasNeonFp16PrepackTest() + : seed_(19287), gen_(seed_), distrib_(0, 255) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16Prepack"; + } + + void ExecuteShort(void) override { + TestPrepack<1, 1, 16>(); + TestPrepack<1, 15, 16>(); + TestPrepack<1, 31, 16>(); + TestPrepack<8, 1, 16>(); + TestPrepack<8, 16, 16>(); + TestPrepack<9, 31, 16>(); + TestPrepack<9, 33, 32>(); + TestPrepack<15, 33, 16>(); + TestPrepack<17, 67, 16>(); + TestPrepack<17, 96, 128>(); + TestPrepack<263, 263, 16>(); + } +}; + +class MlasNeonFp16DequantBTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + std::uniform_real_distribution _distribFp; + MatrixGuardBuffer input_, zero_points_; + MatrixGuardBuffer dequant_, ref_, scales_; + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? 
(v >> 4) : (v & 0x0f); + } + + template + void DequantB(const uint8_t* src, MLAS_FP16* dst, const MLAS_FP16* scales, const uint8_t* zero_points) { + constexpr size_t blkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ld_src = (blkNum * BlkLen + 1) / 2; + constexpr size_t ld_dst = blkNum * BlkLen; + constexpr size_t ld_zp = (blkNum + 1) / 2; + size_t n = 0; + for (; n + 8 <= N; n += 8) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + for (size_t i = 0; i < BlkLen; i += 2, i_dst += 8) { + for (size_t j = 0; j < 8; ++j, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp + ld_zp * j], blk) : 8); + float scale = scales[i_scale + blkNum * j]; + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + + for (; n < N; ++n) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp], blk) : 8); + float scale = scales[i_scale]; + for (size_t i = 0; i < BlkLen; i += 16, i_dst += 8) { + for (size_t j = 0; j < 16; j += 2, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = std::abs(v0.ToFloat()), f1 = std::abs(v1.ToFloat()); + return std::abs(f0 - f1) <= f1 * rtol + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < 8; ++j) { + size_t idx = n * Ldb + i * 8 + j; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; ++i) { + size_t idx = n * Ldb + i; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i; + } + } + } + + template + void TestDequant() { + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t BCount = BlkNum * BlkLen * N; + constexpr size_t ScaleCount = N * BlkNum; + constexpr size_t ZpSize = N * ((BlkNum + 1) / 2); + + auto InitializeBuffer_i8 = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + auto InitializeBuffer_fp16 = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(_distribFp(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BCount / 2, InitializeBuffer_i8); + const auto* zero_points = zero_points_.GetFilledBuffer(ZpSize, InitializeBuffer_i8); + auto* dequant = dequant_.GetBuffer(BCount); + auto* ref = ref_.GetBuffer(BCount); + const auto* scales = scales_.GetFilledBuffer(ScaleCount, InitializeBuffer_fp16); + 
GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant, reinterpret_cast(input), scales, + UseZeroPoints ? reinterpret_cast(zero_points) : nullptr, + N, K, BlkNum); + DequantB(input, ref, scales, zero_points); + Check(dequant, ref); + } + + public: + MlasNeonFp16DequantBTest() + : seed_(19287), gen_(seed_), distrib_(0, 255), _distribFp(0.5f, 2.0f) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16DequantB"; + } + + void ExecuteShort(void) override { + TestDequant<1, 1, 16, false>(); + TestDequant<1, 1, 16, true>(); + TestDequant<1, 15, 16, false>(); + TestDequant<1, 15, 16, true>(); + TestDequant<1, 31, 16, false>(); + TestDequant<1, 31, 16, true>(); + TestDequant<8, 1, 16, false>(); + TestDequant<8, 1, 16, true>(); + TestDequant<8, 16, 16, false>(); + TestDequant<8, 16, 16, true>(); + TestDequant<9, 31, 16, false>(); + TestDequant<9, 31, 16, true>(); + TestDequant<9, 33, 32, false>(); + TestDequant<9, 33, 32, true>(); + TestDequant<15, 33, 16, false>(); + TestDequant<15, 33, 16, true>(); + TestDequant<17, 67, 16, false>(); + TestDequant<17, 67, 16, true>(); + TestDequant<17, 96, 128, false>(); + TestDequant<17, 96, 128, true>(); + TestDequant<263, 263, 16, false>(); + TestDequant<263, 263, 16, true>(); + } +}; + +class MlasNeonFp16HQ4BitGemmKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + MatrixGuardBuffer A_, B_, C_, ref_, bias_; + + MLAS_FORCEINLINE + void InitializeBuffer(MLAS_FP16* buffer, float min, float max, size_t count) { + std::uniform_real_distribution distrib(min, max); + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib(gen_)); + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + float GetBVal(const MLAS_FP16* B, size_t n, size_t k) { + size_t i; + if ((N & (~7)) > n) { + size_t full8 = n & (~7); + i = full8 * ldb + 8 * k + (n - full8); + } else { + i = n * ldb + k; + } + return B[i].ToFloat(); + } + + template + void MatMul(const MLAS_FP16* A, const MLAS_FP16* B, const MLAS_FP16* bias, MLAS_FP16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = UseBias ? 
bias[n] : 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[m * K + k].ToFloat(); + float b = GetBVal(B, n, k); + accu = accu + a * b; + } + C[m * N + n] = MLAS_FP16(accu); + } + } + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = m * Ldc + n; + ASSERT_TRUE(FloatEqual(target[i], ref[i], 0.02f, 0.055f)) + << " seed " << seed_ + << " v0 " << target[i] << " v1 " << ref[i] + << " m " << m << " n " << n; + } + } + } + + template + void TestHQ4BitGemmKernel() { + static_assert(M <= 2); + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ldb = BlkNum * BlkLen; + + const auto* A = A_.GetFilledBuffer(M * K, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + const auto* B = B_.GetFilledBuffer(ldb * N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + auto* bias = bias_.GetFilledBuffer(N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -5.0f, 5.0f, t); + }); + + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + A, B, UseBias ? bias : nullptr, C, M, N, K, K, ldb, N); + + MatMul(A, B, bias, ref); + Check(C, ref); + } + + public: + MlasNeonFp16HQ4BitGemmKernelTest() + : seed_(19287), gen_(seed_) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16HQ4BitGemmKernel"; + } + + template + void ExecuteShort_T(void) { + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + } + + void ExecuteShort(void) override { + ExecuteShort_T<1>(); + ExecuteShort_T<2>(); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + if (GetMlasPlatform().QNBitGemmDispatch) { + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmPackQuantBData) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + } + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/unittest/test_rope.cpp b/tests/unittest/test_rope.cpp new file mode 100644 index 0000000..3dd6e5e --- /dev/null +++ b/tests/unittest/test_rope.cpp @@ -0,0 +1,140 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_rope.h + +Abstract: + + Tests for MLAS RoPE. 
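+ Each case compares the dispatched MlasRotaryEmbedOneRow kernel against the MlasRotaryEmbedOneRow_FallBack reference across a range of rotary dimensions, with and without interleaving.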
+ +--*/ + +#include "test_util.h" +#include "mlasi.h" +#include "rotary_embedding.h" + +using namespace onnxruntime; + +template +class MlasRoPETest : public MlasTestBase { + public: + void Test(size_t rotary_emb_dim, bool interleaved) { + std::vector input(rotary_emb_dim); + size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim; + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); + + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = static_cast(i + 1.0f); + } + for (size_t i = 0; i < table_len; ++i) { + // https://arxiv.org/pdf/2104.09864 section 3.4.3 + float theta_i = static_cast(pow(10000, -2.0f * i / rotary_emb_dim)); + sin_data[i] = static_cast(std::sin(theta_i)); + cos_data[i] = static_cast(std::cos(theta_i)); + } + + // Call the function + MlasRotaryEmbedOneRow_FallBack(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_ref[0]); + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + + for (size_t i = 0; i < rotary_emb_dim; i++) { + ASSERT_TRUE(CloseEnough(output_impl[i], output_ref[i])) + << "Expected: " << output_ref[i] << " Actual: " << output_impl[i] << "@[" << i << "], " + << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved; + } + } +}; + +// +// Short Execute() test helper to register each test separately by all parameters. +// +template +class RoPEShortExecuteTest : public MlasTestFixture> { + public: + explicit RoPEShortExecuteTest(size_t rotary_emb_dim, bool interleaved) + : rotary_emb_dim_(rotary_emb_dim), + interleaved_(interleaved) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(rotary_emb_dim_, interleaved_); + } + + static size_t RegisterSingleTest(size_t rotary_emb_dim, bool interleaved) { + size_t tests_registered = 0; + + std::string test_suite_name{"RoPE_"}; + if (std::is_same::value) { + test_suite_name += "fp32"; + } else if (std::is_same::value) { + test_suite_name += "fp16"; + } else { + ADD_FAILURE() << "Unknown type passed to test: " << test_suite_name; + return 0; // Return 0 since no test is registered + } + + std::stringstream ss; + ss << "/rotary_emb_dim" << rotary_emb_dim << "/interleaved" << interleaved; + auto test_name = ss.str(); + + testing::RegisterTest( + test_suite_name.c_str(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. 
+ [=]() -> MlasTestFixture>* { + return new RoPEShortExecuteTest(rotary_emb_dim, interleaved); + }); + + tests_registered += 1; + + return tests_registered; + } + + static size_t RegisterShortExecuteTests() { + size_t tests_registered = 0; + tests_registered += RegisterSingleTest(6, false); + tests_registered += RegisterSingleTest(6, true); + tests_registered += RegisterSingleTest(16, false); + tests_registered += RegisterSingleTest(16, true); + tests_registered += RegisterSingleTest(24, false); + tests_registered += RegisterSingleTest(24, true); + tests_registered += RegisterSingleTest(32, false); + tests_registered += RegisterSingleTest(32, true); + tests_registered += RegisterSingleTest(42, false); + tests_registered += RegisterSingleTest(42, true); + tests_registered += RegisterSingleTest(64, false); + tests_registered += RegisterSingleTest(64, true); + tests_registered += RegisterSingleTest(70, false); + tests_registered += RegisterSingleTest(70, true); + return tests_registered; + } + + private: + size_t rotary_emb_dim_; + bool interleaved_; +}; + +// only test float RoPE with avx2 where RopeDispatch is assigned at this moment. +#ifdef MLAS_TARGET_AMD64 +static size_t RoPERegisterAllShortExecuteTests() { + return RoPEShortExecuteTest::RegisterShortExecuteTests() + RoPEShortExecuteTest::RegisterShortExecuteTests(); +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return RoPERegisterAllShortExecuteTests(); + } + return 0; + }); +#endif diff --git a/tests/unittest/test_scaleoutput.cpp b/tests/unittest/test_scaleoutput.cpp index 34f1784..0b58e48 100644 --- a/tests/unittest/test_scaleoutput.cpp +++ b/tests/unittest/test_scaleoutput.cpp @@ -22,7 +22,7 @@ class MlasScaleOutputTest : public MlasTestBase { std::numeric_limits::max()); for (size_t s = 0; s < M * N; s++) { - Input[s] = int_distribution(generator); + Input[s] = int_distribution(generator); // It could be zero Output[s] = OutputRef[s] = real_distribution(generator); } @@ -52,10 +52,14 @@ class MlasScaleOutputTest : public MlasTestBase { constexpr float epsilon = 1e-6f; for (size_t n = 0; n < M * N; n++) { - float diff = std::fabs((Output[n] - OutputRef[n]) / OutputRef[n]); + float outvalue = OutputRef[n]; // When `AccumulateMode` is false, there is a high chance that this value could be zero + float diff = std::fabs(Output[n] - outvalue); + if (outvalue != 0) { + diff /= outvalue; + } ASSERT_LE(diff, epsilon) << " @[" << n / N << "," << n % N << "], total:[" << M << "," << N << "], got:" - << Output[n] << ", expecting:" << OutputRef[n]; + << Output[n] << ", expecting:" << outvalue; } } diff --git a/tests/unittest/test_softcap.cpp b/tests/unittest/test_softcap.cpp new file mode 100644 index 0000000..3ad1d06 --- /dev/null +++ b/tests/unittest/test_softcap.cpp @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
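+// Tests for the fp16 MlasComputeTanh and MlasComputeSoftcap kernels; the fp16 checks run only on ARM64 builds with fp16 vector intrinsics.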
+ +#include "test_util.h" +#include "mlasi.h" +#include "softmax.h" + +class MlasComputeTanhTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputFp16; + MatrixGuardBuffer BufferOutputFp16; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + void TestFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + Input[n] = MLAS_FP16(distribution(generator)); + } + + MlasComputeTanh(Input, Output, N); + + constexpr float AbsoluteTolerance = 5e-3f; + constexpr float RelativeTolerance = 5e-3f; + + for (size_t n = 0; n < N; n++) { + float in = Input[n].ToFloat(); + float ref = std::tanh(in); + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << in << ", got: " << out << ", expecting: " << ref + << ", diff: " << diff << ", r-diff: " << diff / std::fabs(ref); + } + } +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Tanh"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, -3.51562f, 3.51562f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; + +class MlasComputeSoftcapTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputFp16; + MatrixGuardBuffer BufferOutputFp16; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + void TestFp16(size_t N, float MinimumValue, float MaximumValue, float cap) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + Input[n] = MLAS_FP16(distribution(generator)); + } + + MlasComputeSoftcap(Input, Output, N, MLAS_FP16(cap)); + + constexpr float AbsoluteTolerance = 5e-3f; + constexpr float RelativeTolerance = 5e-3f; + + for (size_t n = 0; n < N; n++) { + float in = Input[n].ToFloat(); + float ref = std::tanh(in / cap) * cap; + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << in << ", got: " << out << ", expecting: " << ref << ", r-diff " << diff / std::fabs(ref); + } + } +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Softcap"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, -10.f, 10.f, 3.2f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += 
MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/tests/unittest/test_softmax.cpp b/tests/unittest/test_softmax.cpp index fb4ebbe..f07a0c3 100644 --- a/tests/unittest/test_softmax.cpp +++ b/tests/unittest/test_softmax.cpp @@ -2,6 +2,126 @@ // Licensed under the MIT License. #include "test_util.h" +#include "mlasi.h" +#include "softmax.h" + +class MlasComputeExpTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferInputFp16; + MatrixGuardBuffer BufferOutputFp16; + + void Test(size_t N, float MinimumValue, float MaximumValue) { + float* Input = BufferInput.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + Input[n] = distribution(generator); + } + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = std::exp(Input[n]); + } + + MlasComputeExp(Input, Output, N); + + constexpr float AbsoluteTolerance = 1e-6f; + constexpr float RelativeTolerance = 1e-6f; + + for (size_t n = 0; n < N; n++) { + float diff = std::fabs(Output[n] - OutputReference[n]); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance) + << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n]; + } + } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + void TestFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + Input[n] = MLAS_FP16(distribution(generator)); + } + + MlasComputeExp(Input, Output, N); + + constexpr float AbsoluteTolerance = 5e-4f; + constexpr float RelativeTolerance = 1e-3f; + + for (size_t n = 0; n < N; n++) { + float in = Input[n].ToFloat(); + float ref = std::exp(in); + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << in << ", got: " << out << ", expecting: " << ref << ", r-diff: " << diff / std::fabs(ref); + } + } + + void TestSumFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + float max_val = std::numeric_limits::lowest(); + for (size_t n = 0; n < N; n++) { + Input[n] = MLAS_FP16(distribution(generator)); + max_val = std::fmax(max_val, Input[n].ToFloat()); + } + + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + auto sum = dispatch->SumExp_Fp16(Input, Output, N, MLAS_FP16(-max_val)); + + constexpr float AbsoluteTolerance = 5e-4f; + constexpr float RelativeTolerance = 1e-3f; + + float sum_ref = 0.0f; + for (size_t n = 0; n < N; n++) { + float in = Input[n].ToFloat(); + float ref = std::exp(in - max_val); + sum_ref += ref; + float out = Output[n].ToFloat(); + float diff = 
std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << in << ", got: " << out << ", expecting: " << ref << ", r-diff: " << diff / std::fabs(ref); + } + + float diff = std::fabs(sum.ToFloat() - sum_ref); + ASSERT_TRUE(diff <= 1e-3f || diff <= std::fabs(sum_ref) * 5e-3f) + << " sum: " << sum.ToFloat() << ", expecting: " << sum_ref << ", r-diff: " << diff / std::fabs(sum_ref); + } + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Exp"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { + Test(n, -10.f, 10.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, -17.f, 11.f); + TestSumFp16(n, -10.f, 10.f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; template class MlasSoftmaxTest : public MlasTestBase { @@ -9,6 +129,8 @@ class MlasSoftmaxTest : public MlasTestBase { MatrixGuardBuffer BufferInput; MatrixGuardBuffer BufferOutput; MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferInputFp16; + MatrixGuardBuffer BufferOutputFp16; MLAS_THREADPOOL* threadpool_; void Test(size_t N, size_t D, float MinimumValue, float MaximumValue) { @@ -30,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase { } void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; @@ -44,6 +166,64 @@ class MlasSoftmaxTest : public MlasTestBase { } } +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + void TestReduceMaxFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + float ref = std::numeric_limits::lowest(); + + for (size_t nd = 0; nd < N; nd++) { + Input[nd] = MLAS_FP16(distribution(generator)); + ref = std::fmax(ref, Input[nd].ToFloat()); + } + + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + auto out = dispatch->ReduceMax_Fp16(Input, N).ToFloat(); + + constexpr float AbsoluteTolerance = 1e-3f; + constexpr float RelativeTolerance = 1e-3f; + + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << "ReduceMaxFp16: " << N << ", got: " << out << ", expecting: " << ref + << ", diff: " << diff << ", r-diff: " << diff / std::fabs(ref); + } + + void TestFp16(size_t N, size_t D, float MinimumValue, float MaximumValue, bool LogSoftmax, bool SmoothSoftmax) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N * D); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N * D); + float* InputReference = BufferInput.GetBuffer(N * D); + float* OutputReference = BufferOutputReference.GetBuffer(N * D); + + std::default_random_engine generator(static_cast(N * D)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t nd = 0; nd < N * D; nd++) { + Input[nd] = MLAS_FP16(distribution(generator)); + 
InputReference[nd] = Input[nd].ToFloat(); + } + + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); + ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); + + constexpr float AbsoluteTolerance = 5e-3f; + constexpr float RelativeTolerance = 5e-3f; + + for (size_t nd = 0; nd < N * D; nd++) { + float in = Input[nd].ToFloat(); + float ref = OutputReference[nd]; + float out = Output[nd].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << "LogSoftmax:" << LogSoftmax << ", SmoothSoftmax: " << SmoothSoftmax << ", input " << in + << ", got: " << out << ", expecting: " << ref << ", diff: " << diff << ", r-diff: " << diff / std::fabs(ref); + } + } +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + void ReferenceSoftmax(const float* Input, float* Output, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { for (size_t n = 0; n < N; n++) { float MaximumValue = std::numeric_limits::lowest(); @@ -99,11 +279,32 @@ class MlasSoftmaxTest : public MlasTestBase { void ExecuteShort(void) override { for (size_t d = 1; d < 128; d++) { Test(1, d, -10.f, 10.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestReduceMaxFp16(d, -10.f, 10.f); + TestFp16(1, d, -10.f, 10.f, false, true); + TestFp16(1, d, -10.f, 10.f, true, true); + TestFp16(1, d, -10.f, 10.f, false, false); + TestFp16(1, d, -10.f, 10.f, true, false); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) } Test(3, 128, 20.f, 30.f); Test(63, 95, -150.f, 190.f); Test(16, 211, 20.f, 30.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(3, 128, 3.f, 7.f, false, true); + TestFp16(3, 128, 3.f, 7.f, true, true); + TestFp16(3, 128, 3.f, 7.f, false, false); + TestFp16(3, 128, 3.f, 7.f, true, false); + TestFp16(63, 95, -15.f, 19.f, false, true); + TestFp16(63, 95, -15.f, 19.f, true, true); + TestFp16(63, 95, -15.f, 19.f, false, false); + TestFp16(63, 95, -15.f, 19.f, true, false); + TestFp16(16, 211, -7.f, -3.f, false, true); + TestFp16(16, 211, -7.f, -3.f, true, true); + TestFp16(16, 211, -7.f, -3.f, false, false); + TestFp16(16, 211, -7.f, -3.f, true, false); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) } }; @@ -111,6 +312,7 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe size_t count = 0; if (is_short_execute) { count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); if (GetMlasThreadPool() != nullptr) { count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } diff --git a/tests/unittest/test_sq8bitgemm.cpp b/tests/unittest/test_sq8bitgemm.cpp new file mode 100644 index 0000000..f24ea85 --- /dev/null +++ b/tests/unittest/test_sq8bitgemm.cpp @@ -0,0 +1,869 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sq8bitgemm_neon.cpp + +Abstract: + + Tests for MatMul8Bits kernels on x86 CPU with input A type T1 fp32. 
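+ Covers quantized weight prepacking, activation row quantization, and the end-to-end block quantized GEMM kernels.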
+ +--*/ + +#include +#include + +#include "test_util.h" +#include "mlasi.h" +#include "mlas_q4.h" +#include "qnbitgemm.h" +#include "mlas_qnbit.h" + +class MlasSQ8BitPrepackTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution distrib_u8_; + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer inputB_, inputZp_, refB_, packedBuffer_; + MatrixGuardBuffer inputScale_, refScale_; + MatrixGuardBuffer inputBlkSum_, refBlkSum_, refBlkUnsignedQuantAZeroPointCorrection_; + +#ifdef MLAS_TARGET_ARM64 + template + void PrepackB(const uint8_t* src, uint8_t* dst, float* refBlkUnsignedQuantAZeroPointCorrection) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n / 8 * 8 * ldb + k / 4 * 4 * 8 + (n % 8) * 4 + k % 4; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n / 4 * 4 * ldb + k / 4 * 4 * 4 + (n % 4) * 4 + k % 4; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n * ldb + k; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + } + + template + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum, float* refBlkUnsignedQuantAZeroPointCorrection) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n / 8 * 8 * BlkCount + k * 8 + n % 8; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n / 4 * 4 * BlkCount + k * 4 + n % 4; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? 
static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n * BlkCount + k; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + } + + template + void CheckB(const uint8_t* packedB, const uint8_t* refB) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n / 8 * 8 * ldb + k / 4 * 4 * 8 + (n % 8) * 4 + k % 4; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n / 4 * 4 * ldb + k / 4 * 4 * 4 + (n % 4) * 4 + k % 4; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n * ldb + k; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + } + + template + void CheckScale(const float* packedScale, const float* refScale) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n / 8 * 8 * BlkCount + k * 8 + n % 8; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n / 4 * 4 * BlkCount + k * 4 + n % 4; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n * BlkCount + k; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + } +#else // not MLAS_TARGET_ARM64 + template + void PrepackB(const uint8_t* src, uint8_t* dst, float* blkUnsignedQuantAZeroPointCorrection) { + MLAS_UNREFERENCED_PARAMETER(blkUnsignedQuantAZeroPointCorrection); + + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0; + for (; n + 4 <= N; n += 4) { + size_t k = 0; + for (; k + SubBlkLen <= ldb; k += SubBlkLen) { + for (size_t i = 0; i < 4; ++i) { + std::copy(src + (n + i) * ldb + k, src + (n + i) * ldb + k + SubBlkLen, dst + n * ldb + 4 * k + i * SubBlkLen); + } + } + + for (size_t kk = 0; kk + k + BlkLen <= ldb; kk += BlkLen) { + for (size_t i = 0; i < 4; ++i) { + std::copy(src + (n + i) * ldb + k + kk, src + (n + i) * ldb + k + kk + BlkLen, dst + n * ldb + 4 * k + 4 * kk + i * BlkLen); + } + } + } + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations" +#endif + for (; n < N; ++n) { + 
std::copy(src + n * ldb, src + n * ldb + ldb, dst + n * ldb); + } +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + } + + template + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum, float* blkUnsignedQuantAZeroPointCorrection) { + MLAS_UNREFERENCED_PARAMETER(blkUnsignedQuantAZeroPointCorrection); + + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t BlkPerSubBlk = SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1; + + size_t n = 0; + for (; n + 4 <= N; n += 4) { + size_t k = 0; + for (; k + BlkPerSubBlk <= BlkCount; k += BlkPerSubBlk) { + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < BlkPerSubBlk; ++j) { + auto srcOffset = (n + i) * BlkCount + k + j; + auto scaleDstOffset = n * BlkCount + 4 * k + i * BlkPerSubBlk + j; + auto sumDstOffset = (((n + i) / 16) * BlkCount + k + j) * 16 + (n + i) % 16; + + auto vSum = -scale[srcOffset] * (zp ? static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } + } + for (size_t kk = 0; k + kk < BlkCount; ++kk) { + for (size_t i = 0; i < 4; ++i) { + auto srcOffset = (n + i) * BlkCount + k + kk; + auto scaleDstOffset = n * BlkCount + 4 * k + 4 * kk + i; + auto sumDstOffset = (((n + i) / 16) * BlkCount + k + kk) * 16 + (n + i) % 16; + + auto vSum = -scale[srcOffset] * (zp ? static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } + } + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations" +#endif + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + auto srcOffset = n * BlkCount + k; + auto scaleDstOffset = n * BlkCount + k; + auto sumDstOffset = (((n) / 16) * BlkCount + k) * 16 + (n) % 16; + + auto vSum = -scale[srcOffset] * (zp ? static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + } + + template + void CheckB(const uint8_t* packedB, const uint8_t* refB) { + size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0, N4 = N & (~3), ldbSub = ldb & (~(SubBlkLen - 1)); + for (; n < N4; ++n) { + size_t k = 0; + for (; k < ldbSub && k < K; ++k) { + size_t idx = (n & (~3)) * ldb + (k & (~(SubBlkLen - 1))) * 4 + (n & 3) * SubBlkLen + (k & (SubBlkLen - 1)); + ASSERT_EQ(packedB[idx], refB[idx]) + << " n " << n << " k " << k; + } + for (; k < K; ++k) { + size_t idx = (n & (~3)) * ldb + (k & (~(BlkLen - 1))) * 4 + (n & 3) * BlkLen + (k & (BlkLen - 1)); + ASSERT_EQ(packedB[idx], refB[idx]) + << " n " << n << " k " << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + ASSERT_EQ(packedB[n * ldb + k], refB[n * ldb + k]) + << " n " << n << " k " << k; + } + } + } + + template + void CheckScale(const float* packedScale, const float* refScale) { + size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t BlkPerSubBlk = SubBlkLen > BlkLen ? 
SubBlkLen / BlkLen : 1; + size_t n = 0, N4 = N & (~3), BlkCountSub = BlkCount & (~(BlkPerSubBlk - 1)); + + for (; n < N4; ++n) { + size_t k = 0; + for (; k < BlkCountSub; ++k) { + size_t idx = (n & (~3)) * BlkCount + (k & (~(BlkPerSubBlk - 1))) * 4 + (n & 3) * BlkPerSubBlk + (k & (BlkPerSubBlk - 1)); + ASSERT_EQ(packedScale[idx], refScale[idx]) + << " n " << n << " k " << k; + } + for (; k < BlkCount; ++k) { + size_t idx = (n & (~3)) * BlkCount + k * 4 + (n & 3); + ASSERT_EQ(packedScale[idx], refScale[idx]) + << " n " << n << " k " << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + ASSERT_EQ(packedScale[n * BlkCount + k], refScale[n * BlkCount + k]) + << " n " << n << " k " << k; + } + } + } +#endif // MLAS_TARGET_ARM64 + + template + void CheckBlkSum(const float* packedBlkSum, const float* refBlkSum) { + if (refBlkSum == nullptr) { + return; + } + + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + + for (size_t n = 0; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = (((n) / 16) * BlkCount + k) * 16 + (n) % 16; + ASSERT_EQ(packedBlkSum[idx], refBlkSum[idx]) + << " n " << n << " k " << k; + } + } + } + + template + void TestPrepack() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + constexpr size_t Bits = 8; + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t Ldb = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t PackBCount = N * Ldb; + constexpr size_t ScaleCount = BlkCount * N; + const size_t BufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, Bits, BlkLen, hasZp, SQNBIT_CompInt8); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; + + const auto* inputB = inputB_.GetFilledBuffer(PackBCount, [this](uint8_t* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = static_cast(this->distrib_u8_(this->gen_)); + } + }); + + const auto* inputScale = inputScale_.GetFilledBuffer(ScaleCount, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + const auto* inputZp = hasZp ? inputZp_.GetFilledBuffer(ScaleCount, [this](uint8_t* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = static_cast(this->distrib_u8_(this->gen_)); + } + }) + : nullptr; + + auto* packedBuffer = packedBuffer_.GetBuffer(BufferSize, true); + auto* refB = refB_.GetBuffer(PackBCount, true); + auto* refScale = refScale_.GetBuffer(ScaleCount, true); + auto* refBlkSum = refBlkSum_.GetBuffer(((N + 15) & (~15)) * BlkCount, true); + auto* refBlkUnsignedQuantAZeroPointCorrection = isQuantAUnsigned ? refBlkUnsignedQuantAZeroPointCorrection_.GetBuffer(((N + 15) & (~15)) * BlkCount, true) : nullptr; + + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned); + + // Models the packing calls from MatmulNBits operator - we will have 3 separate calls + // for 3 different inputs in the Prepack() function + // The first call prepacks the quantized weights (and accumulates necessary metadata for BlkUnsignedQuantAZeroPointCorrection). + // The second call prepacks the scales. + // The third call prepacks the zero points. + + // The inputScale and zero points will be ignored while prepacking the weights (if they are provided). 
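+ // On ARM64 builds where quantized activations are unsigned, the packed buffer also carries BlkUnsignedQuantAZeroPointCorrection, a per-block correction for the activation zero-point shift; the reference PrepackB/PrepackBlkSumAndScale below accumulate the same term so CheckBlkSum can verify it.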
+ MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, + inputScale, hasZp, inputZp, nullptr); + + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + inputScale, hasZp, nullptr, nullptr); + + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + nullptr, hasZp, inputZp, nullptr); + + PrepackB(inputB, refB, refBlkUnsignedQuantAZeroPointCorrection); + PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum, refBlkUnsignedQuantAZeroPointCorrection); + + CheckB(reinterpret_cast(packedQuantB.PackedQuantBData), refB); + CheckScale(packedQuantB.PackedQuantBScale, refScale); + CheckBlkSum(packedQuantB.QuantBBlkSum, refBlkSum); + CheckBlkSum(packedQuantB.BlkUnsignedQuantAZeroPointCorrection, refBlkUnsignedQuantAZeroPointCorrection); + } + + public: + MlasSQ8BitPrepackTest() + : seed_(19287), gen_(seed_), distrib_u8_(0, 255), distrib_f32_(-10.f, 10.f) { + } + + static const char* GetTestSuiteName() { + return "SQ8BitPrepack"; + } + + template + void Execute(void) { + TestPrepack(); + TestPrepack(); + } + + void ExecuteShort(void) override { + auto& platform = GetMlasPlatform(); + + if (platform.Avx512Supported_) { + Execute<1, 1, 16, 128>(); + Execute<1, 1, 32, 128>(); + Execute<1, 1, 64, 128>(); + Execute<1, 1, 128, 128>(); + Execute<1, 1, 256, 128>(); + + Execute<16, 4, 16, 128>(); + Execute<32, 4, 16, 128>(); + Execute<64, 4, 16, 128>(); + Execute<128, 4, 16, 128>(); + + Execute<15, 5, 16, 128>(); + Execute<15, 5, 32, 128>(); + Execute<15, 5, 64, 128>(); + Execute<15, 5, 128, 128>(); + Execute<15, 5, 256, 128>(); + + Execute<17, 8, 16, 128>(); + Execute<17, 8, 32, 128>(); + Execute<17, 8, 64, 128>(); + Execute<17, 8, 128, 128>(); + Execute<17, 8, 256, 128>(); + + Execute<256, 16, 16, 128>(); + Execute<257, 17, 32, 128>(); + Execute<255, 15, 64, 128>(); + Execute<256, 17, 128, 128>(); + Execute<257, 16, 256, 128>(); + } else { + Execute<1, 1, 16, 64>(); + Execute<1, 1, 32, 64>(); + Execute<1, 1, 64, 64>(); + Execute<1, 1, 128, 64>(); + Execute<1, 1, 256, 64>(); + + Execute<16, 4, 16, 64>(); + Execute<32, 8, 16, 64>(); + Execute<64, 12, 32, 64>(); + Execute<128, 16, 64, 64>(); + + Execute<15, 3, 16, 64>(); + Execute<15, 4, 32, 64>(); + Execute<15, 5, 64, 64>(); + Execute<15, 6, 128, 64>(); + Execute<15, 7, 256, 64>(); + Execute<15, 8, 16, 64>(); + Execute<15, 9, 16, 64>(); + + Execute<17, 3, 16, 64>(); + Execute<17, 4, 32, 64>(); + Execute<17, 5, 64, 64>(); + Execute<17, 6, 128, 64>(); + Execute<17, 7, 256, 64>(); + Execute<17, 8, 16, 64>(); + Execute<17, 9, 16, 64>(); + + Execute<159, 16, 16, 64>(); + Execute<160, 17, 32, 64>(); + Execute<161, 15, 64, 64>(); + Execute<160, 17, 128, 64>(); + Execute<159, 16, 256, 64>(); + Execute<3072, 128, 16, 64>(); + } + } +}; + +class MlasSQ8BitQuantAKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution distrib_u8_; + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer workspace_, refQuantA_; + MatrixGuardBuffer inputA_, refScale_, refBlkSum_; + + template + void QuantA(const float* inputA, uint8_t* quantA, float* scalePtr, float* blkSum, bool quantAUnsigned) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t input_lda = K; + + constexpr size_t Bits = 8; + constexpr size_t output_lda = (((K + BlkLen - 
1) & (~(BlkLen - 1))) * Bits + 7) / 8; + + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < BlkCount; ++j) { + float vAbsMax = 0.f; + for (size_t k = 0; k < std::min(BlkLen, K - j * BlkLen); ++k) { + size_t input_idx = i * input_lda + j * BlkLen + k; + vAbsMax = std::max(vAbsMax, fabsf(inputA[input_idx])); + } + + float scale = vAbsMax / 127.f; + float invScale = vAbsMax == 0.f ? 0.f : 127.f / vAbsMax; + scalePtr[i * BlkCount + j] = scale; + + float vSum = 0.f; + for (size_t k = 0; k < BlkLen; ++k) { + size_t input_idx = i * input_lda + j * BlkLen + k; + size_t output_idx = i * output_lda + j * BlkLen + k; + if (k < std::min(BlkLen, K - j * BlkLen)) { + const auto input_val = inputA[input_idx]; + // Round to nearest, ties away from zero + // float v = std::clamp(std::roundf(input_val * invScale), -128.f, 127.f); + + // Round to nearest, ties to even + float v = std::clamp(std::nearbyint(input_val * invScale), -128.f, 127.f); + + if (quantAUnsigned) { + quantA[output_idx] = static_cast(v + 128.f); + vSum += v + 128.f; + } else { + reinterpret_cast(quantA)[output_idx] = static_cast(v); + vSum += v; + } + } else { + quantA[output_idx] = 0; + } + } + blkSum[i * BlkCount + j] = vSum * scale; + } + } + } + + template + void CheckQuantA(const uint8_t* quantA, const uint8_t* refQuantA) { + constexpr size_t lda = (K + BlkLen - 1) & (~(BlkLen - 1)); + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < lda; ++j) { + size_t idx = i * lda + j; + ASSERT_EQ(quantA[idx], refQuantA[idx]) << " at i=" << i << " j=" << j; + } + } + } + + template + void CheckScale(const float* scale, const float* refScale) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < BlkCount; ++j) { + size_t idx = i * BlkCount + j; + ASSERT_EQ(scale[idx], refScale[idx]) << " at i=" << i << " j=" << j; + } + } + } + + template + void TestQuantA() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + const auto* dispatch = GetMlasPlatform().QNBitGemmDispatch; + constexpr size_t Bits = 8; + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t Lda = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t PackACount = M * Lda; + constexpr size_t ScaleCount = M * BlkCount; + const size_t BufferSize = MlasQNBitGemmBatchWorkspaceSize(M, 1, K, 1, Bits, BlkLen, true, SQNBIT_CompInt8); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; + + const auto* inputA = inputA_.GetFilledBuffer(M * K, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + auto* workspace = workspace_.GetBuffer(BufferSize, true); + auto* refQuantA = refQuantA_.GetBuffer(PackACount, true); + auto* refScale = refScale_.GetBuffer(ScaleCount, true); + auto* refBlkSum = refBlkSum_.GetBuffer(ScaleCount, true); + + const size_t Alignment = dispatch->QNBitGemmPerGemmWorkspaceAlignment(BlkLen, SQNBIT_CompInt8); + const uintptr_t WorkspaceAddress = reinterpret_cast(workspace); + auto* quantAPtr = reinterpret_cast((WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))); + auto* scaleAPtr = reinterpret_cast(quantAPtr + PackACount); + auto* blkSumAPtr = scaleAPtr + ScaleCount; + + for (size_t i = 0; i < M; ++i) { + dispatch->QuantizeARowComputeBlkSum_CompInt8(BlkLen, inputA + i * K, K, quantAPtr + i * Lda, scaleAPtr + i * BlkCount, blkSumAPtr + i * BlkCount); + } + + QuantA(inputA, refQuantA, refScale, refBlkSum, isQuantAUnsigned); + 
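// The reference QuantA above is meant to mirror QuantizeARowComputeBlkSum_CompInt8, including round-to-nearest-even quantization and the +128 shift applied when activations are quantized as unsigned. +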
+    const size_t Alignment = dispatch->QNBitGemmPerGemmWorkspaceAlignment(BlkLen, SQNBIT_CompInt8);
+    const uintptr_t WorkspaceAddress = reinterpret_cast<uintptr_t>(workspace);
+    auto* quantAPtr = reinterpret_cast<std::byte*>((WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)));
+    auto* scaleAPtr = reinterpret_cast<float*>(quantAPtr + PackACount);
+    auto* blkSumAPtr = scaleAPtr + ScaleCount;
+
+    for (size_t i = 0; i < M; ++i) {
+      dispatch->QuantizeARowComputeBlkSum_CompInt8(BlkLen, inputA + i * K, K, quantAPtr + i * Lda, scaleAPtr + i * BlkCount, blkSumAPtr + i * BlkCount);
+    }
+
+    QuantA<M, K, BlkLen>(inputA, refQuantA, refScale, refBlkSum, isQuantAUnsigned);
+
+    CheckQuantA<M, K, BlkLen>(reinterpret_cast<const uint8_t*>(quantAPtr), refQuantA);
+    CheckScale<M, K, BlkLen>(scaleAPtr, refScale);
+    CheckScale<M, K, BlkLen>(blkSumAPtr, refBlkSum);
+  }
+
+ public:
+  MlasSQ8BitQuantAKernelTest()
+      : seed_(19287), gen_(seed_), distrib_u8_(0, 255), distrib_f32_(-10.f, 10.f) {
+  }
+
+  static const char* GetTestSuiteName() {
+    return "SQ8BitQuantA";
+  }
+
+  void ExecuteShort(void) override {
+    TestQuantA<1, 16, 16>();
+    TestQuantA<1, 1, 32>();
+    TestQuantA<1, 1, 64>();
+    TestQuantA<1, 1, 128>();
+    TestQuantA<1, 1, 256>();
+
+    TestQuantA<4, 16, 16>();
+    TestQuantA<8, 32, 16>();
+    TestQuantA<12, 64, 32>();
+    TestQuantA<16, 128, 64>();
+
+    TestQuantA<3, 15, 16>();
+    TestQuantA<4, 15, 32>();
+    TestQuantA<5, 15, 64>();
+    TestQuantA<6, 15, 128>();
+    TestQuantA<7, 15, 256>();
+    TestQuantA<8, 15, 16>();
+    TestQuantA<9, 15, 16>();
+
+    TestQuantA<3, 17, 16>();
+    TestQuantA<4, 17, 32>();
+    TestQuantA<5, 17, 64>();
+    TestQuantA<6, 17, 128>();
+    TestQuantA<7, 17, 256>();
+    TestQuantA<8, 17, 16>();
+    TestQuantA<9, 17, 16>();
+
+    TestQuantA<2, 159, 16>();
+    TestQuantA<3, 159, 16>();
+    TestQuantA<17, 160, 32>();
+    TestQuantA<15, 161, 64>();
+    TestQuantA<17, 160, 128>();
+    TestQuantA<16, 159, 256>();
+
+    TestQuantA<1, 3072, 16>();
+  }
+};
+
+class MlasSQ8BitGemmKernelTest : public MlasTestBase {
+ private:
+  unsigned int seed_;
+  std::mt19937 gen_;  // mersenne_twister_engine seeded with rd()
+  std::uniform_real_distribution<float> distrib_f32_;
+  MatrixGuardBuffer<uint8_t> packedBuffer_, workspace_, packedB_, Zp_;
+  MatrixGuardBuffer<float> A_, B_, C_, ref_, bias_, scale_;
+
+  bool FloatEqual(float v0, float v1, float rtol, float atol) {
+    return std::abs(v0 - v1) <= std::abs(v1 * rtol) + atol;
+  }
+
+  template <size_t M, size_t N, size_t K>
+  void MatMul(const float* A, size_t lda, const float* B, const float* bias, float* C, size_t ldc) {
+    for (size_t m = 0; m < M; ++m) {
+      for (size_t n = 0; n < N; ++n) {
+        float accu = bias ? bias[n] : 0.0f;
+        for (size_t k = 0; k < K; ++k) {
+          float a = A[m * lda + k];
+          float b = B[n * K + k];
+          accu += a * b;
+        }
+        C[m * ldc + n] = accu;
+      }
+    }
+  }
+
+  template <size_t M, size_t N, size_t K, size_t BlkLen>
+  void Check(const float* target, const float* ref, size_t ldc, float rtol, float atol) {
+    for (size_t m = 0; m < M; ++m) {
+      for (size_t n = 0; n < N; ++n) {
+        size_t i = m * ldc + n;
+        ASSERT_TRUE(FloatEqual(target[i], ref[i], rtol, atol))
+            << " M " << M << " K " << K << " N " << N << " BlkLen " << BlkLen
+            << " v0 " << target[i] << " v1 " << ref[i]
+            << " m " << m << " n " << n;
+      }
+    }
+  }
+
+  template <size_t M, size_t N, size_t K, size_t BlkLen, bool HasZp, bool HasBias>
+  void TestSQ8BitGemmKernel() {
+    if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return;
+
+    constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen;
+    constexpr size_t ldb = BlkCount * BlkLen;
+    constexpr size_t lda = ldb;
+    constexpr size_t ldc = (N + 15) & (~15);
+    const auto* A = A_.GetFilledBuffer(M * lda, [this](float* p, size_t t) {
+      for (size_t i = 0; i < t; i++) {
+        p[i] = this->distrib_f32_(this->gen_);
+      }
+    });
+
+    auto* B = B_.GetFilledBuffer(K * N, [this](float* p, size_t t) {
+      for (size_t i = 0; i < t; i++) {
+        p[i] = this->distrib_f32_(this->gen_);
+      }
+    });
+
+    size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes;
+    MlasBlockwiseQuantizedBufferSizes<8>((int)(BlkLen), true, (int)K, (int)N,
+                                         q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes);
+
+    auto* inputB = packedB_.GetBuffer(q_data_size_in_bytes, true);
+    auto* inputScale = scale_.GetBuffer(q_scale_size, true);
+    auto* inputZp = HasZp ? Zp_.GetBuffer(q_zp_size_in_bytes, true) : nullptr;
+
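+    // Quantize B blockwise to 8 bits and immediately dequantize it back into B, so the
+    // float reference MatMul below works on exactly the values the quantized weights can
+    // represent; the remaining mismatch against the kernel then comes mainly from the
+    // activation quantization and integer accumulation.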
+    MlasQuantizeBlockwise<float, 8>(
+        inputB,
+        inputScale,
+        inputZp,
+        B,
+        BlkLen,
+        true,
+        K,
+        N,
+        N,
+        nullptr);
+
+    MlasDequantizeBlockwise<float, 8>(
+        B,
+        inputB,
+        inputScale,
+        inputZp,
+        BlkLen,
+        true,
+        K,
+        N,
+        nullptr);
+
+    size_t bufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, 8, BlkLen, HasZp, SQNBIT_CompInt8);
+    auto* packedBuffer = packedBuffer_.GetBuffer(bufferSize, true);
+
+    // Models the packing calls from the MatmulNBits operator - we will have 3 separate calls
+    // for 3 different inputs in the Prepack() function.
+    // The first call prepacks the quantized weights (and accumulates the necessary metadata for BlkUnsignedQuantAZeroPointCorrection).
+    // The second call prepacks the scales.
+    // The third call prepacks the zero points.
+
+    // The inputScale and zero points will be ignored while prepacking the weights (if they are provided).
+    MlasQNBitGemmPackQuantBData(
+        N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer,
+        inputScale, HasZp, inputZp, nullptr);
+
+    MlasQNBitGemmPackQuantBData(
+        N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer,
+        inputScale, HasZp, nullptr, nullptr);
+
+    MlasQNBitGemmPackQuantBData(
+        N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer,
+        nullptr, HasZp, inputZp, nullptr);
+
+    const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned;
+    PackedQuantBDataStruct<float, 8> packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned);
+
+    auto* C = C_.GetBuffer(M * ldc, true);
+    auto* ref = ref_.GetBuffer(M * ldc, true);
+
+    auto* bias = HasBias ? bias_.GetFilledBuffer(N, [](float* p, size_t t) {
+      for (size_t i = 0; i < t; i++) {
+        p[i] = (float)(5 + i);
+      }
+    })
+                         : nullptr;
+
+    const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, 8, BlkLen, HasZp, SQNBIT_CompInt8);
+    auto* workspace = workspace_.GetBuffer(workspace_size, true);
+
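+    // Describe the single GEMM: A stays in float while B is supplied through the packed
+    // buffer produced above (QuantBDataWorkspace / PackedQuantBData) together with its
+    // scales and, when present, its zero points.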
+    MLAS_QNBIT_GEMM_DATA_PARAMS<float> data;
+    data.A = A;
+    data.lda = lda;
+    data.QuantBDataWorkspace = packedBuffer;
+    data.PackedQuantBData = packedQuantB.PackedQuantBData;
+    data.QuantBScale = inputScale;
+    data.QuantBZeroPoint = inputZp;
+    data.Bias = bias;
+    data.C = C;
+    data.ldc = ldc;
+
+    MlasQNBitGemmBatch(M, N, K, 1, 8, BlkLen, SQNBIT_CompInt8, &data, workspace, nullptr);
+
+    MatMul<M, N, K>(A, lda, B, bias, ref, ldc);
+    Check<M, N, K, BlkLen>(C, ref, ldc, 0.01f, 0.02f);
+  }
+
+ public:
+  MlasSQ8BitGemmKernelTest()
+      : seed_(1234), gen_(seed_), distrib_f32_(-0.25f, 0.25f) {
+  }
+
+  static const char* GetTestSuiteName() {
+    return "SQ8BitGemmKernel";
+  }
+
+  template <size_t M, size_t N, size_t K, size_t BlkLen>
+  void Execute(void) {
+    TestSQ8BitGemmKernel<M, N, K, BlkLen, false, false>();
+    TestSQ8BitGemmKernel<M, N, K, BlkLen, false, true>();
+
+    TestSQ8BitGemmKernel<M, N, K, BlkLen, true, false>();
+    TestSQ8BitGemmKernel<M, N, K, BlkLen, true, true>();
+  }
+
+  void ExecuteShort(void) override {
+    Execute<1, 16, 1, 16>();
+    Execute<1, 1, 1, 16>();
+    Execute<7, 2, 4, 16>();
+    Execute<7, 128, 4, 16>();
+    Execute<8, 497, 5, 16>();
+    Execute<1, 3072, 128, 16>();
+    Execute<2, 3072, 128, 16>();
+
+    Execute<1, 1, 1, 32>();
+    Execute<8, 33, 5, 32>();
+    Execute<8, 513, 9, 32>();
+    Execute<1, 3072, 128, 32>();
+    Execute<2, 3072, 128, 32>();
+
+    Execute<1, 1, 1, 64>();
+    Execute<8, 497, 9, 64>();
+    Execute<1, 3072, 128, 64>();
+    Execute<2, 3072, 128, 64>();
+
+    Execute<1, 1, 1, 128>();
+    Execute<6, 255, 7, 128>();
+    Execute<5, 257, 9, 128>();
+    Execute<1, 3072, 128, 128>();
+    Execute<2, 3072, 128, 128>();
+
+    Execute<1, 1, 1, 256>();
+    Execute<7, 255, 7, 256>();
+    Execute<6, 257, 7, 256>();
+    Execute<1, 3072, 128, 256>();
+    Execute<2, 3072, 128, 256>();
+  }
+};
+
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
+  size_t count = 0;
+  if (is_short_execute) {
+    count += MlasDirectShortExecuteTests<MlasSQ8BitPrepackTest>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasSQ8BitQuantAKernelTest>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasSQ8BitGemmKernelTest>::RegisterShortExecute();
+  }
+  return count;
+});
diff --git a/tests/unittest/test_sqnbitgemm.cpp b/tests/unittest/test_sqnbitgemm.cpp
index 0710981..91ce359 100644
--- a/tests/unittest/test_sqnbitgemm.cpp
+++ b/tests/unittest/test_sqnbitgemm.cpp
@@ -18,11 +18,11 @@ Module Name:
 #include "mlas_q4.h"
 #include "mlas_qnbit.h"
 
-static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) {
+static constexpr const char* ComputeTypeName(MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType) {
   switch (ComputeType) {
-    case CompFp32:
+    case SQNBIT_CompFp32:
       return "Fp32";
-    case CompInt8:
+    case SQNBIT_CompInt8:
       return "Int8";
     default:
       return "unknown";
@@ -63,16 +63,16 @@ class MlasSQNBitGemmTest : public MlasTestBase {
                 float* C,
                 size_t ldc,
                 void* Workspace,
-                MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
+                MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
                 MLAS_THREADPOOL* Threadpool) {
-    MLAS_SQNBIT_GEMM_DATA_PARAMS params;
+    MLAS_QNBIT_GEMM_DATA_PARAMS<float> params;
     params.A = A;
     params.lda = lda;
     params.Bias = Bias;
     params.C = C;
     params.ldc = ldc;
 #ifdef MLAS_TARGET_AMD64_IX86
-    if (ComputeType == CompInt8) {
+    if (ComputeType == SQNBIT_CompInt8) {
       params.QuantBDataWorkspace = PackedQuantBDataWorkspace;
     }
 #endif
@@ -81,7 +81,7 @@ class MlasSQNBitGemmTest : public MlasTestBase {
     params.QuantBZeroPoint = QuantBZeroPoint;
     params.PostProcessor = nullptr;
 
-    MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, &params, Workspace, Threadpool);
+    MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, &params, Workspace, Threadpool);
   }
 
   void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) {
@@ -201,7 +201,7 @@ class MlasSQNBitGemmTest : public MlasTestBase {
 
  public:
   void Test(size_t M, size_t N, size_t K,
-            MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
+            MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
             bool WithThreadpool, bool Symmetric, bool WithBias) {
     MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr;
 
@@ -246,9 +246,9 @@ class MlasSQNBitGemmTest : public MlasTestBase {
     uint8_t* QuantBZeroPoint = nullptr;
     {
       size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
-      MlasBlockwiseQuantizedBufferSizes(BlkBitWidth, BlkLen, /* columnwise */ true,
-                                        static_cast<int>(K), static_cast<int>(N),
-                                        QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);
+      MlasBlockwiseQuantizedBufferSizes<BlkBitWidth>(BlkLen, /* columnwise */ true,
+                                                     static_cast<int>(K), static_cast<int>(N),
+                                                     QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);
 
       QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes);
       QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize);
@@ -265,19 +265,19 @@ class MlasSQNBitGemmTest : public MlasTestBase {
     }
 
     void* Workspace = nullptr;
-    if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType);
+    if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, !Symmetric, ComputeType);
         WorkspaceSize > 0) {
       Workspace = BufferWorkspace.GetBuffer(WorkspaceSize);
     }
 
     void* PackedQuantBDataWorkspace = nullptr;
-    if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType);
+    if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, !Symmetric, ComputeType);
         PackedQuantBDataSize > 0) {
       PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize);
       bool has_zp_input = QuantBZeroPoint != nullptr;
-      MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace,
-                                   QuantBScale, has_zp_input, QuantBZeroPoint,
-                                   GetMlasThreadPool());
+      MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace,
+                                  QuantBScale, has_zp_input, QuantBZeroPoint,
+                                  GetMlasThreadPool());
     }
 
     CallGemm(M, N, K,
@@ -289,9 +289,9 @@ class MlasSQNBitGemmTest : public MlasTestBase {
              ComputeType,
              Threadpool);
 
-    if (ComputeType == CompFp32) {
+    if (ComputeType == SQNBIT_CompFp32) {
       CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
-    } else if (ComputeType == CompInt8) {
+    } else if (ComputeType == SQNBIT_CompInt8) {
       CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference);
     } else {
       FAIL() << "Test is not implemented for compute type "
@@ -324,7 +324,7 @@ template <size_t BlkBitWidth>
 class SQNBitGemmShortExecuteTest : public MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth>> {
  public:
   explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K,
-                                      MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
+                                      MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
                                       bool WithThreadpool, bool Symmetric, bool WithBias)
       : M_(M),
         N_(N),
@@ -341,11 +341,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture<MlasSQNBitGemmTest<BlkBitWidth>> {
 #include "test_util.h"
-#include "core/mlas/lib/mlasi.h"
+#include "mlasi.h"
 
 #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
diff --git a/tests/unittest/test_transpose.cpp b/tests/unittest/test_transpose.cpp
index 8fa9841..2f446ec 100644
--- a/tests/unittest/test_transpose.cpp
+++ b/tests/unittest/test_transpose.cpp
@@ -3,12 +3,13 @@
 
 #include "test_util.h"
 
-template <typename ElementType>
+template <typename ElementType, bool Threaded>
 class MlasTransposeTest : public MlasTestBase {
  private:
   MatrixGuardBuffer<ElementType> BufferInput;
   MatrixGuardBuffer<ElementType> BufferOutput;
   MatrixGuardBuffer<ElementType> BufferOutputReference;
+  MLAS_THREADPOOL* threadpool_;
 
   void Test(size_t M, size_t N) {
@@ -16,7 +17,7 @@ class MlasTransposeTest : public MlasTestBase {
     ElementType* Output = BufferOutput.GetBuffer(M * N);
     ElementType* OutputReference = BufferOutputReference.GetBuffer(M * N);
 
-    MlasTranspose(Input, Output, M, N);
+    MlasTranspose(Input, Output, M, N, threadpool_);
     ReferenceTranspose(Input, OutputReference, M, N);
 
     ASSERT_EQ(memcmp(Output, OutputReference, M * N * sizeof(ElementType)), 0) << " [" << M << "," << N << "]";
@@ -31,11 +32,23 @@ class MlasTransposeTest : public MlasTestBase {
   }
 
  public:
+  MlasTransposeTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {}
+
   static const char* GetTestSuiteName() {
-    static const std::string suite_name = std::string("Transpose_Size") + std::to_string(int(sizeof(ElementType)));
+    static const std::string suite_name = std::string("Transpose_") +
+                                          GetTypeString() +
+                                          std::string(Threaded ? "_Threaded" : "_SingleThread");
     return suite_name.c_str();
   }
 
+  static const std::string GetTypeString() {
+    if (std::is_same<ElementType, float>::value) return std::string("FP32");
+    if (std::is_same<ElementType, uint32_t>::value) return std::string("U32");
+    if (std::is_same<ElementType, uint16_t>::value) return std::string("U16");
+    if (std::is_same<ElementType, uint8_t>::value) return std::string("U8");
+    return std::string("unknown");
+  }
+
   void ExecuteShort(void) override {
     for (size_t m = 1; m <= 32; m++) {
       for (size_t n = 1; n <= 32; n++) {
@@ -48,9 +61,14 @@ class MlasTransposeTest : public MlasTestBase {
 static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
   size_t count = 0;
   if (is_short_execute) {
-    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
-    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint16_t>>::RegisterShortExecute();
-    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<float, false>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<float, true>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t, false>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t, true>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint16_t, false>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint16_t, true>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t, false>>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t, true>>::RegisterShortExecute();
   }
   return count;
 });
diff --git a/tests/unittest/test_util.h b/tests/unittest/test_util.h
index 9b52ce2..95a4ebe 100644
--- a/tests/unittest/test_util.h
+++ b/tests/unittest/test_util.h
@@ -7,6 +7,7 @@
 #include "gtest/gtest.h"
 
 #include
+#include
 #include
 #include
 #include
@@ -14,8 +15,6 @@
 #include
 #include
 #include
-#include
-
 #if defined(_WIN32)
 #include
 #else
@@ -37,42 +36,10 @@
 #define _countof(_Array) (sizeof(_Array) / sizeof(_Array[0]))
 #endif
 
-#ifndef MLAS_THROW_EX
-#ifdef MLAS_NO_EXCEPTION
-
-MLAS_FORCEINLINE
-void
-MlasPrintFinalMessage(const std::string& msg)
-{
-#if defined(__ANDROID__)
-    __android_log_print(ANDROID_LOG_ERROR, "mlas", "%s", msg.c_str());
-#else
-    // TODO, consider changing the output of the error message from std::cerr to logging when the
-    // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr
-    // output might not be easily accesible on some systems such as mobile
-    // TODO, see if we need to change the output of the error message from std::cerr to NSLog for
-    // iOS
-    std::cerr << msg << std::endl;
-#endif
-}
-
-#define MLAS_THROW_EX(ex, what)         \
-    do {                                \
-        std::string msg = #ex;          \
-        msg.append(what);               \
-        MlasPrintFinalMessage(msg);     \
-        abort();                        \
-    } while (false)
-
-#else
-
-#define MLAS_THROW_EX(ex, ...) throw ex(__VA_ARGS__)
-
-#endif  // MLAS_NO_EXCEPTION
-#endif
-
 MLAS_THREADPOOL* GetMlasThreadPool(void);
-
+#ifdef BUILD_MLAS_NO_ONNXRUNTIME
+#include "matrix_buffer.h"
+#else
 template <typename T>
 class MatrixGuardBuffer {
  public:
@@ -123,7 +90,11 @@ class MatrixGuardBuffer {
 
 #if defined(_WIN32)
     if (VirtualAlloc(_BaseBuffer, BytesToAllocate, MEM_COMMIT, PAGE_READWRITE) == nullptr) {
-      MLAS_THROW_EX(std::bad_alloc);
+#ifdef BUILD_MLAS_NO_ONNXRUNTIME
+      abort();
+#else
+      ORT_THROW_EX(std::bad_alloc);
+#endif
     }
 #else
     if (mprotect(_BaseBuffer, BytesToAllocate, PROT_READ | PROT_WRITE) != 0) {
@@ -151,24 +122,24 @@ class MatrixGuardBuffer {
       return GetFilledBuffer(
           Elements,
           [](T* start, size_t size) {
-            std::fill_n(start, size, T(0));
+            std::fill_n(start, size, T(0.0f));
          });
     }
 
     return GetFilledBuffer(
         Elements,
         [](T* start, size_t size) {
-          constexpr int offset = -21;
-          constexpr int range = 43;
+          constexpr float offset = -21.f;
+          constexpr float range = 43.f;
 
-          int FillValue = 11;
+          float FillValue = 11.f;
          T* FillAddress = start;
 
          for (size_t i = 0; i < size; i++) {
            auto itemv = FillValue - offset;
            *FillAddress++ = (T)(itemv);
 
-            FillValue += 7;
-            FillValue %= range;
+            FillValue += 7.f;
+            FillValue = FillValue >= range ? FillValue - range : FillValue;
          }
        });
   }
@@ -194,7 +165,7 @@ class MatrixGuardBuffer {
   size_t _BaseBufferSize;
   T* _GuardAddress;
 };
-
+#endif
 class MlasTestBase {
  public:
   virtual ~MlasTestBase(void) {}